Commit f9181773 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Permute packed f4_t values

parent ee8937a8
......@@ -40,14 +40,14 @@ struct f4x2_pk_t
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4);
else
return data & 0b00001111;
}
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
return (x0 << 4) | (x1 & 0b00001111);
}
};
......
......@@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
return ret;
#endif
}
......
......@@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[0] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
......@@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
......
......@@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert)
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
for(ck::index_t id = 0; id < 256 * 16; id++)
{
printf("%f\n", out.data()[id]);
}
// V = X * P; X - E8M0 scale, P - FP4
// If X = NaN, then V = NaN regardless of P
......@@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert)
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp4_nan_ids;
fp4_nan_ids.insert(0b11111111); //-NaN
fp4_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp4_nan_id : fp4_nan_ids)
{
auto idx = exp_id * 256 + fp4_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp4_id = 0; fp4_id < 256; fp4_id++)
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
{
if(fp4_nan_ids.find(fp4_id) != fp4_nan_ids.end())
continue;
uint8_t fp4_uid = static_cast<uint8_t>(fp4_id);
auto idx = exp_id * 256 + fp4_uid;
auto idx = exp_id * 16 + fp4_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f4_t(fp4_uid & 0b00001111)))
......@@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert)
auto i = 256 * 16;
// f4x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
EXPECT_EQ(out[i++], 1.0f);
EXPECT_EQ(out[i++], -4.0f);
// f32x2 -> f4x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
EXPECT_EQ(out[i++], 0.5f);
EXPECT_EQ(out[i++], -2.0f);
// SR
EXPECT_EQ(out[i++], 0.5f);
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
......@@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert)
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(312.5f)))
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(5.0f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment