Commit 0328b06e authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

fixes

parent 7f179833
...@@ -330,6 +330,7 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -330,6 +330,7 @@ int run_gemm_example(int argc, char* argv[])
{ {
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
} }
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t") else if(data_type == "pk_int4_t")
{ {
// TODO: Add support for bhalf_t ADataType // TODO: Add support for bhalf_t ADataType
...@@ -337,6 +338,7 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -337,6 +338,7 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::pk_int4_t, ck_tile::pk_int4_t,
ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{}); ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
} }
#endif
else else
{ {
throw std::runtime_error("Unsupported data_type!"); throw std::runtime_error("Unsupported data_type!");
...@@ -360,6 +362,7 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -360,6 +362,7 @@ int run_gemm_example(int argc, char* argv[])
{ {
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{}); return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
} }
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t") else if(data_type == "pk_int4_t")
{ {
// TODO: Add support for bhalf_t ADataType // TODO: Add support for bhalf_t ADataType
...@@ -367,6 +370,7 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -367,6 +370,7 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::pk_int4_t, ck_tile::pk_int4_t,
ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{}); ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
} }
#endif
else else
{ {
throw std::runtime_error("Unsupported data_type!"); throw std::runtime_error("Unsupported data_type!");
......
...@@ -60,6 +60,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem> ...@@ -60,6 +60,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize = static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize; ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize = static constexpr index_t BPackedSize =
......
...@@ -21,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -21,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize = static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize; ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize = static constexpr index_t BPackedSize =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment