Commit 83fcce27 authored by Rostyslav Geyyer
Browse files

Update test vectors

parent e38b4a33
......@@ -326,18 +326,21 @@ TEST(MXFP4, DeviceScaledConvert)
// Test-vector generator callable from host and device: maps an index `i` to a float.
//
// NOTE(review): this text is a diff rendering whose +/- markers were stripped. The two
// consecutive `return` statements below look like the pre-commit and post-commit versions
// of a single line -- confirm against the repository. Read literally, only the first
// `return` can ever execute; the second is unreachable.
__host__ __device__ float vec16_generator(ck::index_t i)
{
// Sign is -1 for i < 8 and +1 otherwise; magnitude is 2 raised to (i % 8).
return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8);
// Keeps only the low 4 bits of `i`, builds an f4_t from them and converts it to float.
// Presumably the 4 bits are taken as the raw FP4 encoding -- TODO confirm f4_t semantics.
return type_convert<float>(f4_t(i & 0b00001111));
}
// Test-vector generator callable from host and device, built on vec16_generator:
// indices below 16 and indices from 16 upward are handled by separate branches.
//
// NOTE(review): this text is a diff rendering whose +/- markers were stripped. In each
// branch the two consecutive `return` statements look like the pre-commit and post-commit
// versions of the same statement -- confirm against the repository. Read literally, only
// the first `return` of each branch can ever execute.
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
// Forwards the index (reduced modulo 16) to vec16_generator unchanged in value.
return vec16_generator(i % 16);
// Forwards the index to vec16_generator as-is (i is already below 16 here).
return vec16_generator(
i); // all positive values, then all negative values in ascending order
}
else
{
// Scales the vec16_generator value for (i % 16) by 1.5.
return 1.5f * vec16_generator(i % 16);
// Subtracts vec16_generator(i) from the float value of NumericLimits<f4_t>::Max().
return type_convert<float>(ck::NumericLimits<f4_t>::Max()) -
vec16_generator(
i); // all positive values, then all negative values in descending order
}
}
......@@ -394,9 +397,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert)
device_out.FromDevice(out.data());
auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
......@@ -456,9 +461,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR)
device_out.FromDevice(out.data());
auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
......@@ -481,14 +488,14 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
auto scale2 = e8m0_bexp_t(2.0f);
f4x32_t f4x32{};
float32_t float32{};
ck::static_for<0, N / 2, 1>{}([&](auto ii) {
f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{})
.pack(type_convert<f4_t>(vec32_generator(2 * ii) / 16.0f),
type_convert<f4_t>(vec32_generator(2 * ii + 1) / 16.0f));
.pack(type_convert<f4_t>(vec32_generator(2 * ii) / type_convert<float>(scale2)),
type_convert<f4_t>(vec32_generator(2 * ii + 1) / type_convert<float>(scale2)));
});
float32 = scaled_type_convert<float32_t>(scale2, f4x32);
......@@ -516,9 +523,11 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
device_out.FromDevice(out.data());
auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
......
......@@ -85,15 +85,9 @@ TEST(MXFP4, FP4ToFP32)
std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
// device_out.SetValue(-21.0f);
// device_completed.SetValue(-21.0f);
run_test_mx_fp4_to_fp32<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// f4x2 -> f32x2
......@@ -106,12 +100,9 @@ TEST(MXFP4, FP32ToFP4RNE)
std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// f32x2 -> f4x2
......@@ -125,12 +116,9 @@ TEST(MXFP4, FP32ToFP4SR)
std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_sr<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// SR
......@@ -143,12 +131,9 @@ TEST(MXFP4, FP32ToFP4SRFailing)
std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// SR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment