Commit ab5d0278 authored by kylasa's avatar kylasa Committed by Sam Wu
Browse files

Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)



* Support bf16/fb8/bf8 datatypes for ck_tile/gemm

* remove commented out code.

* Addressing code review comments and enabling universal_gemm for all the supported data types.

* Merge conflict resolution.

* Solve the memory pipeline compilation error. Merge with the new change of CShuffle

* finish the feature, pass the tests

* Fix the pipeline and add the benchmark script for other data types

---------
Co-authored-by: default avatarThomasNing <thomas.ning@amd.com>
parent 9b5dfba2
...@@ -159,7 +159,7 @@ struct GemmKernel ...@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{ {
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value) is_any_of<CDataType, fp16_t, bf16_t>::value)
{ {
if(kargs.k_batch != 1) if(kargs.k_batch != 1)
...@@ -240,7 +240,7 @@ struct GemmKernel ...@@ -240,7 +240,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -255,7 +255,7 @@ struct GemmKernel ...@@ -255,7 +255,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -321,7 +321,7 @@ struct GemmKernel ...@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1), make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{}, number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -519,7 +519,7 @@ struct GemmKernel ...@@ -519,7 +519,7 @@ struct GemmKernel
{ {
// Do not compile in case where we have unsupported // Do not compile in case where we have unsupported
// VectorSizeC & data type configuration. // VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)) is_any_of<CDataType, fp16_t, bf16_t>::value))
{ {
RunGemm<memory_operation_enum::atomic_add>( RunGemm<memory_operation_enum::atomic_add>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment