Commit 83fcce27 authored by Rostyslav Geyyer
Browse files

Update test vectors

parent e38b4a33
...@@ -326,18 +326,21 @@ TEST(MXFP4, DeviceScaledConvert) ...@@ -326,18 +326,21 @@ TEST(MXFP4, DeviceScaledConvert)
__host__ __device__ float vec16_generator(ck::index_t i) __host__ __device__ float vec16_generator(ck::index_t i)
{ {
return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8); return type_convert<float>(f4_t(i & 0b00001111));
} }
__host__ __device__ float vec32_generator(ck::index_t i) __host__ __device__ float vec32_generator(ck::index_t i)
{ {
if(i < 16) if(i < 16)
{ {
return vec16_generator(i % 16); return vec16_generator(
i); // all positive values, then all negative values in ascending order
} }
else else
{ {
return 1.5f * vec16_generator(i % 16); return type_convert<float>(ck::NumericLimits<f4_t>::Max()) -
vec16_generator(
i); // all positive values, then all negative values in descending order
} }
} }
...@@ -394,9 +397,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert) ...@@ -394,9 +397,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert)
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
auto i = 0; auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) { ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl; EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
}); });
EXPECT_EQ(N, completed); EXPECT_EQ(N, completed);
...@@ -456,9 +461,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR) ...@@ -456,9 +461,11 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR)
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
auto i = 0; auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) { ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl; EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
}); });
EXPECT_EQ(N, completed); EXPECT_EQ(N, completed);
...@@ -481,14 +488,14 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_ ...@@ -481,14 +488,14 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_
return; return;
} }
auto scale2 = e8m0_bexp_t(4.0f); auto scale2 = e8m0_bexp_t(2.0f);
f4x32_t f4x32{}; f4x32_t f4x32{};
float32_t float32{}; float32_t float32{};
ck::static_for<0, N / 2, 1>{}([&](auto ii) { ck::static_for<0, N / 2, 1>{}([&](auto ii) {
f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}) f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{})
.pack(type_convert<f4_t>(vec32_generator(2 * ii) / 16.0f), .pack(type_convert<f4_t>(vec32_generator(2 * ii) / type_convert<float>(scale2)),
type_convert<f4_t>(vec32_generator(2 * ii + 1) / 16.0f)); type_convert<f4_t>(vec32_generator(2 * ii + 1) / type_convert<float>(scale2)));
}); });
float32 = scaled_type_convert<float32_t>(scale2, f4x32); float32 = scaled_type_convert<float32_t>(scale2, f4x32);
...@@ -516,9 +523,11 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) ...@@ -516,9 +523,11 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
auto i = 0; auto i = 0;
auto scale2 = e8m0_bexp_t(2.0f);
ck::static_for<0, N, 1>{}([&](auto ii) { ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl; EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
<< "ii: " << ii << std::endl;
}); });
EXPECT_EQ(N, completed); EXPECT_EQ(N, completed);
......
...@@ -85,15 +85,9 @@ TEST(MXFP4, FP4ToFP32) ...@@ -85,15 +85,9 @@ TEST(MXFP4, FP4ToFP32)
std::vector<float> out(2, -1.0f); std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float)); DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
// device_out.SetValue(-21.0f);
// device_completed.SetValue(-21.0f);
run_test_mx_fp4_to_fp32<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer())); run_test_mx_fp4_to_fp32<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
// f4x2 -> f32x2 // f4x2 -> f32x2
...@@ -106,12 +100,9 @@ TEST(MXFP4, FP32ToFP4RNE) ...@@ -106,12 +100,9 @@ TEST(MXFP4, FP32ToFP4RNE)
std::vector<float> out(2, -1.0f); std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float)); DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer())); run_test_mx_fp32_to_fp4_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
// f32x2 -> f4x2 // f32x2 -> f4x2
...@@ -125,12 +116,9 @@ TEST(MXFP4, FP32ToFP4SR) ...@@ -125,12 +116,9 @@ TEST(MXFP4, FP32ToFP4SR)
std::vector<float> out(2, -1.0f); std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float)); DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_sr<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer())); run_test_mx_fp32_to_fp4_sr<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
// SR // SR
...@@ -143,12 +131,9 @@ TEST(MXFP4, FP32ToFP4SRFailing) ...@@ -143,12 +131,9 @@ TEST(MXFP4, FP32ToFP4SRFailing)
std::vector<float> out(2, -1.0f); std::vector<float> out(2, -1.0f);
DeviceMem device_out(2 * sizeof(float)); DeviceMem device_out(2 * sizeof(float));
// DeviceMem device_completed(sizeof(uint64_t));
run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer())); run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
// uint64_t completed = 0;
// device_completed.FromDevice(&completed);
device_out.FromDevice(out.data()); device_out.FromDevice(out.data());
// SR // SR
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment