"driver/conv_driver.cpp" did not exist on "1566b31736d191fe3a43dd5efa59968e44191729"
Commit a3678d26 authored by Sudhir Kylasa's avatar Sudhir Kylasa
Browse files

Addressing code review comments.

parent 8086bbe3
......@@ -15,7 +15,7 @@
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
......@@ -31,10 +31,10 @@
#endif
template <typename DataType>
struct GemmBasicTypeConfig;
struct GemmTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
struct GemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
......@@ -44,7 +44,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
struct GemmTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
......@@ -53,7 +53,7 @@ struct GemmBasicTypeConfig<ck_tile::bf16_t>
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
......@@ -62,7 +62,7 @@ struct GemmBasicTypeConfig<ck_tile::fp8_t>
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
struct GemmTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
......
......@@ -240,19 +240,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
void run_gemm_instance(std::string data_type, std::string a_layout, std::string b_layout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
if(data_type == "fp16")
......@@ -351,4 +340,20 @@ int run_gemm_example(int argc, char* argv[])
}
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
return run_gemm_instance(data_type, a_layout, b_layout);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
......@@ -77,7 +77,6 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
......
......@@ -167,7 +167,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment