Commit f9181773 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Permute packed f4_t values

parent ee8937a8
...@@ -40,14 +40,14 @@ struct f4x2_pk_t ...@@ -40,14 +40,14 @@ struct f4x2_pk_t
{ {
static_assert(I < 2, "Index is out of range."); static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0) if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4); return (data >> 4);
else
return data & 0b00001111;
} }
__host__ __device__ inline type pack(const type x0, const type x1) __host__ __device__ inline type pack(const type x0, const type x1)
{ {
return (x1 << 4) | (x0 & 0b00001111); return (x0 << 4) | (x1 & 0b00001111);
} }
}; };
......
...@@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b ...@@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0); return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else #else
float2_t ret{utils::to_float<f4_t>( float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})), scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
utils::to_float<f4_t>( utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))}; scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
return ret; return ret;
#endif #endif
} }
......
...@@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) ...@@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise; uint32_t bitwise;
f4x2_t f4x2_array[4]; f4x2_t f4x2_array[4];
} value{0}; } value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[0] / scale); uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[1] / scale); uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l; value.bitwise = (h << 4) | l;
return value.f4x2_array[0]; return value.f4x2_array[0];
#endif #endif
...@@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) ...@@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t bitwise; uint32_t bitwise;
f4x2_t f4x2_array[4]; f4x2_t f4x2_array[4];
} value{0}; } value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng); uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng); uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l; value.bitwise = (h << 4) | l;
return value.f4x2_array[0]; return value.f4x2_array[0];
#endif #endif
......
...@@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert) ...@@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert)
device_completed.FromDevice(&completed); device_completed.FromDevice(&completed);
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
for(ck::index_t id = 0; id < 256 * 16; id++)
{
printf("%f\n", out.data()[id]);
}
// V = X * P; X - E8M0 scale, P - FP4 // V = X * P; X - E8M0 scale, P - FP4
// If X = NaN, then V = NaN regardless of P // If X = NaN, then V = NaN regardless of P
...@@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert) ...@@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert)
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
} }
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp4_nan_ids;
fp4_nan_ids.insert(0b11111111); //-NaN
fp4_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp4_nan_id : fp4_nan_ids)
{
auto idx = exp_id * 256 + fp4_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{ {
if(exp_id == e8m0_nan_id) if(exp_id == e8m0_nan_id)
continue; continue;
for(ck::index_t fp4_id = 0; fp4_id < 256; fp4_id++) for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
{ {
if(fp4_nan_ids.find(fp4_id) != fp4_nan_ids.end())
continue;
uint8_t fp4_uid = static_cast<uint8_t>(fp4_id); uint8_t fp4_uid = static_cast<uint8_t>(fp4_id);
auto idx = exp_id * 256 + fp4_uid; auto idx = exp_id * 16 + fp4_uid;
ASSERT_FLOAT_EQ(out[idx], ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) * type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f4_t(fp4_uid & 0b00001111))) type_convert<float>(f4_t(fp4_uid & 0b00001111)))
...@@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert) ...@@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert)
auto i = 256 * 16; auto i = 256 * 16;
// f4x2 -> f32x2 // f4x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f)); EXPECT_EQ(out[i++], 1.0f);
EXPECT_EQ(out[i++], powf(2.0f, -8.0f)); EXPECT_EQ(out[i++], -4.0f);
// f32x2 -> f4x2 // f32x2 -> f4x2
// RNE // RNE
EXPECT_EQ(out[i++], -4.0f); EXPECT_EQ(out[i++], 0.5f);
EXPECT_EQ(out[i++], 2.0f); EXPECT_EQ(out[i++], -2.0f);
// SR // SR
EXPECT_EQ(out[i++], 0.5f);
EXPECT_EQ(out[i++], -2.0f); EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even /// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1]; EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1 #if 1
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
...@@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert) ...@@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert)
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Lowest())) EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Lowest()))
<< "out[i-1]: " << out[i - 1]; << "out[i-1]: " << out[i - 1];
#endif #endif
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(312.5f))) EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(5.0f)))
<< "out[i-1]: " << out[i - 1]; << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed); EXPECT_EQ(test_size, completed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment