"...composable_kernel_rocm.git" did not exist on "72c9f129cd7b918f71a966a1dfc3c83bb31bd78c"
Commit 24faa1fc authored by aska-0096's avatar aska-0096
Browse files

Add f32_16x16x16_bf16 unit test

parent 790e21ec
......@@ -55,10 +55,14 @@ int main(int, char*[])
bool pass = true;
// clang-format off
// |SrcType |DstType |GPUAccType |CPUAccType |AccNum
pass &= run_test<ck::half_t, float, float, float, 8 >();
pass &= run_test<ck::half_t, ck::half_t, float, float, 8 >();
pass &= run_test<ck::bhalf_t, ck::bhalf_t, float, float, 8 >();
pass &= run_test<ck::half_t, ck::half_t, ck::half_t, ck::half_t, 16 >();
pass &= run_test<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float, 16 >();
pass &= run_test<int8_t, int8_t, int32_t, int32_t, 8 >();
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
pass &= run_test<int4_t, int4_t, int32_t, int32_t, 8 >();
#endif
// clang-format on
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
......
......@@ -32,6 +32,18 @@ builtin_wmma_naive_selector<half16_t,
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
// Specialization of builtin_wmma_naive_selector for bhalf16_t operands whose
// accumulator is a StaticBufferTupleOfVector<Vgpr, float, 1, 8, true>.
//
// reg_a, reg_b : the two bhalf16_t inputs, passed through unchanged.
// reg_c        : accumulator buffer; its vector at index Number<0> is handed
//                by reference to the intrinsic wrapper.
//
// The work is delegated to intrin_wmma_f32_16x16x16_bf16_w32<16, 16>::Run.
template <>
__device__ void
builtin_wmma_naive_selector<bhalf16_t,
                            StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
    const bhalf16_t& reg_a,
    const bhalf16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
    using wmma_op = intrin_wmma_f32_16x16x16_bf16_w32<16, 16>;

    // Name the accumulator vector first so the call below reads as (A, B, C).
    auto& acc_vec = reg_c.GetVectorTypeReference(Number<0>{});
    wmma_op::Run(reg_a, reg_b, acc_vec);
}
template <>
__device__ void
builtin_wmma_naive_selector<half16_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment