Commit 610f9a34 authored by Sudhir Kylasa
Browse files

Addressing code review comments.

parent a3678d26
......@@ -18,7 +18,7 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
......
......@@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using ADataType = typename GemmTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
......
......@@ -240,8 +240,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
void run_gemm_instance(std::string data_type, std::string a_layout, std::string b_layout)
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
if(data_type == "fp16")
......@@ -340,20 +351,4 @@ void run_gemm_instance(std::string data_type, std::string a_layout, std::string
}
}
// Example driver: parses the command line with create_args, reads the
// "prec", "a_layout" and "b_layout" string options, and forwards them to
// run_gemm_instance.
//
// Returns -1 when create_args reports failure; otherwise returns the result
// of the run_gemm_instance call.
// NOTE(review): a run_gemm_instance signature visible elsewhere in this diff
// has return type void; returning its result from an int function would not
// compile in that case — confirm which declaration is in effect.
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
// Bail out before touching any option if argument parsing failed.
if(!result)
return -1;
// Row/Col are not referenced anywhere in this body.
// NOTE(review): presumably left over from layout dispatch done elsewhere — confirm.
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
return run_gemm_instance(data_type, a_layout, b_layout);
}
// Process entry point. The exit status is the logical negation of
// run_gemm_example's result: a non-zero result exits with 0, a zero result
// exits with 1.
// NOTE(review): run_gemm_example is seen returning -1 when argument parsing
// fails; the negation turns that into exit status 0 — confirm this is intended.
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
......@@ -142,7 +142,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC<ODataType>(),
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
......
......@@ -248,7 +248,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
......@@ -263,7 +263,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
......@@ -329,7 +329,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
......@@ -527,7 +527,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment