Unverified Commit 0fcbb25f authored by deepsek's avatar deepsek Committed by GitHub
Browse files

fix: preprocessor directives logic error if/else (#1764)

* fix: preprocessors logic error if/else

* fix: added macros as preferred by CK team
parent 54de3e55
...@@ -21,7 +21,6 @@ enum struct GemmDataType ...@@ -21,7 +21,6 @@ enum struct GemmDataType
F16_F16_F16, // 1 F16_F16_F16, // 1
F16_F8_F16, // 2 F16_F8_F16, // 2
F16_I8_F16, // 3 F16_I8_F16, // 3
}; };
#define OP_NAME "grouped_gemm_fixed_nk" #define OP_NAME "grouped_gemm_fixed_nk"
...@@ -39,7 +38,6 @@ std::vector<int> argToIntArray(char* input) ...@@ -39,7 +38,6 @@ std::vector<int> argToIntArray(char* input)
{ {
out.push_back(std::stoi(item)); out.push_back(std::stoi(item));
} }
return out; return out;
} }
...@@ -83,14 +81,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -83,14 +81,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1; const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
using F32 = float;
using F16 = ck::half_t;
#if defined(CK_ENABLE_FP8)
using F8 = ck::f8_t;
#endif
using BF16 = ck::bhalf_t;
using I8 = int8_t;
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
if(argc == 17) if(argc == 17)
...@@ -99,13 +89,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -99,13 +89,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter = std::stoi(argv[16]); n_iter = std::stoi(argv[16]);
} }
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
I8, ck::half_t,
BF16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -123,12 +112,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -123,12 +112,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
I8, ck::half_t,
BF16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -146,14 +135,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -146,14 +135,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #if defined(CK_ENABLE_FP8)
#if defined(CK_ENABLE_FP16) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F16, ck::f8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -171,12 +159,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -171,12 +159,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F16, ck::f8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -195,13 +183,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -195,13 +183,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter); n_iter);
} }
#endif #endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) #if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F8, int8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -219,12 +207,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -219,12 +207,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F8, int8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -238,18 +226,19 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -238,18 +226,19 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, 1,
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) #if defined(CK_ENABLE_BF16)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN) #if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
I8, int8_t,
F16, ck::bhalf_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -267,12 +256,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -267,12 +256,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
I8, int8_t,
F16, ck::bhalf_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -286,10 +275,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -286,10 +275,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
1, kbatch,
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif
#endif #endif
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment