Commit 668dda6f authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Fix pool tests for OCP FP8 data type

parent 97042d87
...@@ -1063,7 +1063,7 @@ struct non_native_vector_base< ...@@ -1063,7 +1063,7 @@ struct non_native_vector_base<
StaticallyIndexedArray<data_v, 1> dNx1; StaticallyIndexedArray<data_v, 1> dNx1;
} data_; } data_;
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v{a}} {} __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {}
__host__ __device__ constexpr non_native_vector_base(T f) __host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f)) : non_native_vector_base(bit_cast<data_t>(f))
{ {
......
...@@ -915,6 +915,10 @@ TEST(FP8OCP, TestAsType) ...@@ -915,6 +915,10 @@ TEST(FP8OCP, TestAsType)
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}), ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
ck::type_convert<f8_t>(test_vec.at(i))); ck::type_convert<f8_t>(test_vec.at(i)));
}); });
ck::non_native_vector_base<ck::f8_ocp_t, 2> nnvb_f8x2(ck::type_convert<f8_t>(-10.0f));
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<0>{}), ck::type_convert<f8_t>(-10.0f));
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<1>{}), ck::type_convert<f8_t>(-10.0f));
} }
TEST(FP8OCP, TestAsTypeReshape) TEST(FP8OCP, TestAsTypeReshape)
...@@ -988,6 +992,10 @@ TEST(BF8OCP, TestAsType) ...@@ -988,6 +992,10 @@ TEST(BF8OCP, TestAsType)
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}), ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
ck::type_convert<bf8_t>(test_vec.at(i))); ck::type_convert<bf8_t>(test_vec.at(i)));
}); });
ck::non_native_vector_base<bf8_t, 2> nnvb_bf8x2(ck::type_convert<bf8_t>(-10.0f));
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<0>{}), ck::type_convert<bf8_t>(-10.0f));
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<1>{}), ck::type_convert<bf8_t>(-10.0f));
} }
TEST(BF8OCP, TestAsTypeReshape) TEST(BF8OCP, TestAsTypeReshape)
......
...@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types); ...@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types); TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types); TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); } TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); } TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); } TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); } TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
......
...@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types); ...@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types); TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types); TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); } TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); } TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); } TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); } TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment