Commit a3678d26 authored by Sudhir Kylasa's avatar Sudhir Kylasa
Browse files

Addressing code review comments.

parent 8086bbe3
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT #ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#endif #endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
...@@ -31,10 +31,10 @@ ...@@ -31,10 +31,10 @@
#endif #endif
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmTypeConfig;
template <> template <>
struct GemmBasicTypeConfig<ck_tile::half_t> struct GemmTypeConfig<ck_tile::half_t>
{ {
using ADataType = ck_tile::half_t; using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t; using BDataType = ck_tile::half_t;
...@@ -44,7 +44,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -44,7 +44,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
}; };
template <> template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t> struct GemmTypeConfig<ck_tile::bf16_t>
{ {
using ADataType = ck_tile::bf16_t; using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t;
...@@ -53,7 +53,7 @@ struct GemmBasicTypeConfig<ck_tile::bf16_t> ...@@ -53,7 +53,7 @@ struct GemmBasicTypeConfig<ck_tile::bf16_t>
}; };
template <> template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t> struct GemmTypeConfig<ck_tile::fp8_t>
{ {
using ADataType = ck_tile::fp8_t; using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t; using BDataType = ck_tile::fp8_t;
...@@ -62,7 +62,7 @@ struct GemmBasicTypeConfig<ck_tile::fp8_t> ...@@ -62,7 +62,7 @@ struct GemmBasicTypeConfig<ck_tile::fp8_t>
}; };
template <> template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t> struct GemmTypeConfig<ck_tile::bf8_t>
{ {
using ADataType = ck_tile::bf8_t; using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t; using BDataType = ck_tile::bf8_t;
......
...@@ -240,19 +240,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -240,19 +240,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[]) void run_gemm_instance(std::string data_type, std::string a_layout, std::string b_layout)
{ {
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
if(data_type == "fp16") if(data_type == "fp16")
...@@ -351,4 +340,20 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -351,4 +340,20 @@ int run_gemm_example(int argc, char* argv[])
} }
} }
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
return run_gemm_instance(data_type, a_layout, b_layout);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -77,7 +77,6 @@ struct CShuffleEpilogue ...@@ -77,7 +77,6 @@ struct CShuffleEpilogue
* *
* @return The vector store size for C tensor. * @return The vector store size for C tensor.
*/ */
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{ {
constexpr index_t MaxVectorStoreSize = 16; constexpr index_t MaxVectorStoreSize = 16;
......
...@@ -167,7 +167,7 @@ struct GemmKernel ...@@ -167,7 +167,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{ {
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 && if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value) is_any_of<CDataType, fp16_t, bf16_t>::value)
{ {
if(kargs.k_batch != 1) if(kargs.k_batch != 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment