Commit 610f9a34 authored by Sudhir Kylasa's avatar Sudhir Kylasa
Browse files

Addressing code review comments.

parent a3678d26
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#endif #endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
......
...@@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, ...@@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result) if(!result)
return -1; return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType; using ADataType = typename GemmTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType; using BDataType = typename GemmTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType; using CDataType = typename GemmTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType; using AccDataType = typename GemmTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
......
...@@ -240,8 +240,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -240,8 +240,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
void run_gemm_instance(std::string data_type, std::string a_layout, std::string b_layout) int run_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
if(data_type == "fp16") if(data_type == "fp16")
...@@ -340,20 +351,4 @@ void run_gemm_instance(std::string data_type, std::string a_layout, std::string ...@@ -340,20 +351,4 @@ void run_gemm_instance(std::string data_type, std::string a_layout, std::string
} }
} }
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
return run_gemm_instance(data_type, a_layout, b_layout);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -142,7 +142,7 @@ struct CShuffleEpilogue ...@@ -142,7 +142,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize, TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration, kMPerIteration,
kNPerIteration, kNPerIteration,
GetVectorSizeC<ODataType>(), GetVectorSizeC(),
tile_distribution_pattern::thread_raked>; tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
......
...@@ -248,7 +248,7 @@ struct GemmKernel ...@@ -248,7 +248,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0) if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -263,7 +263,7 @@ struct GemmKernel ...@@ -263,7 +263,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0) if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -329,7 +329,7 @@ struct GemmKernel ...@@ -329,7 +329,7 @@ struct GemmKernel
c_ptr, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1), make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{}, number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -527,7 +527,7 @@ struct GemmKernel ...@@ -527,7 +527,7 @@ struct GemmKernel
{ {
// Do not compile in case where we have unsupported // Do not compile in case where we have unsupported
// VectorSizeC & data type configuration. // VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 && if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)) is_any_of<CDataType, fp16_t, bf16_t>::value))
{ {
RunGemm<memory_operation_enum::atomic_add>( RunGemm<memory_operation_enum::atomic_add>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment