Commit 7bae1691 authored by Po-Yen, Chen
Browse files

Adapt the new GridwiseGemm interface

parent ceebf306
...@@ -389,6 +389,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -389,6 +389,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -396,10 +399,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -396,10 +399,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -434,8 +435,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -434,8 +435,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
LoopSched>; LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public GridwiseGemm::Argument
{ {
using Parent = typename GridwiseGemm::Argument;
Argument(const ADataType* p_a_grid_real, Argument(const ADataType* p_a_grid_real,
const ADataType* p_a_grid_imag, const ADataType* p_a_grid_imag,
const BDataType* p_b_grid_real, const BDataType* p_b_grid_real,
...@@ -443,55 +446,53 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -443,55 +446,53 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real, CDataType* p_c_grid_real,
CDataType* p_c_grid_imag, CDataType* p_c_grid_imag,
CDataType* p_workspace, CDataType* p_workspace,
index_t MRaw, index_t M_,
index_t NRaw, index_t N_,
index_t KRaw, index_t K_,
index_t StrideA, index_t StrideA_,
index_t StrideB, index_t StrideB_,
index_t StrideC, index_t StrideC_,
AElementwiseOperation a_element_op, index_t MPadded_,
BElementwiseOperation b_element_op, index_t NPadded_,
CElementwiseOperation c_element_op) index_t KPadded_,
: p_a_grid_real_{p_a_grid_real}, index_t AK0_,
index_t BK0_)
: Parent(nullptr,
nullptr,
nullptr,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
AK0_,
BK0_),
p_a_grid_real_{p_a_grid_real},
p_a_grid_imag_{p_a_grid_imag}, p_a_grid_imag_{p_a_grid_imag},
p_b_grid_real_{p_b_grid_real}, p_b_grid_real_{p_b_grid_real},
p_b_grid_imag_{p_b_grid_imag}, p_b_grid_imag_{p_b_grid_imag},
p_c_grid_real_{p_c_grid_real}, p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag}, p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_workspace}, p_aux_grid_{p_workspace}
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
c_grid_desc_m_ = c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
c_grid_desc_m_ = c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
} }
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize();
} }
// private: // private:
...@@ -503,38 +504,32 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -503,38 +504,32 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_imag_; CDataType* p_c_grid_imag_;
CDataType* p_aux_grid_; CDataType* p_aux_grid_;
CDataType* p_aux_2_grid_; CDataType* p_aux_2_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M c_grid_desc_m_; CGridDesc_M c_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
}; };
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceOp::Argument; // void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, Argument karg = arg;
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, if(stream_config.log_level_ > 0)
arg.block_2_ctile_map_)) {
// Print(karg);
}
if(!GridwiseGemm::CheckValidity(karg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
const index_t grid_size = index_t gdx, gdy, gdz;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
const auto K = const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0; float ave_time = 0;
...@@ -578,224 +573,114 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -578,224 +573,114 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype karg.p_a_grid = karg.p_a_grid_real_;
CDataType, karg.p_b_grid = karg.p_b_grid_real_;
AElementwiseOperation, karg.p_c_grid = karg.p_aux_grid_;
BElementwiseOperation, ave_time += launch_and_time_kernel(
CElementwiseOperation, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, karg.p_a_grid = karg.p_a_grid_imag_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, karg.p_b_grid = karg.p_b_grid_imag_;
typename GridwiseGemm::DefaultBlock2CTileMap, karg.p_c_grid = karg.p_aux_2_grid_;
true>; ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config, stream_config,
subtract_kernel, subtract_kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_), make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)), const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_real_), make_tuple(karg.p_c_grid_real_),
Subtract{}); Subtract{});
ave_time += karg.p_a_grid = karg.p_a_grid_real_;
launch_and_time_kernel(stream_config, karg.p_b_grid = karg.p_b_grid_imag_;
kernel, karg.p_c_grid = karg.p_aux_grid_;
dim3(grid_size), ave_time += launch_and_time_kernel(
dim3(BlockSize), stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
0,
arg.p_a_grid_real_, karg.p_a_grid = karg.p_a_grid_imag_;
arg.p_b_grid_imag_, karg.p_b_grid = karg.p_b_grid_real_;
arg.p_aux_grid_, karg.p_c_grid = karg.p_aux_2_grid_;
arg.a_element_op_, ave_time += launch_and_time_kernel(
arg.b_element_op_, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config, stream_config,
add_kernel, add_kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_), make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)), const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_imag_), make_tuple(karg.p_c_grid_imag_),
Add{}); Add{});
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype karg.p_a_grid = karg.p_a_grid_real_;
CDataType, karg.p_b_grid = karg.p_b_grid_real_;
AElementwiseOperation, karg.p_c_grid = karg.p_aux_grid_;
BElementwiseOperation, ave_time += launch_and_time_kernel(
CElementwiseOperation, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, karg.p_a_grid = karg.p_a_grid_imag_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, karg.p_b_grid = karg.p_b_grid_imag_;
typename GridwiseGemm::DefaultBlock2CTileMap, karg.p_c_grid = karg.p_aux_2_grid_;
false>; ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config, stream_config,
subtract_kernel, subtract_kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_), make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)), const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_real_), make_tuple(karg.p_c_grid_real_),
Subtract{}); Subtract{});
ave_time += karg.p_a_grid = karg.p_a_grid_real_;
launch_and_time_kernel(stream_config, karg.p_b_grid = karg.p_b_grid_imag_;
kernel, karg.p_c_grid = karg.p_aux_grid_;
dim3(grid_size), ave_time += launch_and_time_kernel(
dim3(BlockSize), stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
0,
arg.p_a_grid_real_, karg.p_a_grid = karg.p_a_grid_imag_;
arg.p_b_grid_imag_, karg.p_b_grid = karg.p_b_grid_real_;
arg.p_aux_grid_, karg.p_c_grid = karg.p_aux_2_grid_;
arg.a_element_op_, ave_time += launch_and_time_kernel(
arg.b_element_op_, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config, stream_config,
add_kernel, add_kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_), make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_), make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)), const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_imag_), make_tuple(karg.p_c_grid_imag_),
Add{}); Add{});
} }
...@@ -816,12 +701,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -816,12 +701,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& karg)
{ {
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(karg);
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
} }
// polymorphic // polymorphic
...@@ -837,15 +719,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -837,15 +719,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_real, CDataType* p_c_real,
CDataType* p_c_imag, CDataType* p_c_imag,
CDataType* p_workspace, CDataType* p_workspace,
index_t MRaw, index_t M,
index_t NRaw, index_t N,
index_t KRaw, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op) CElementwiseOperation)
{ {
return Argument{p_a_real, return Argument{p_a_real,
p_a_imag, p_a_imag,
...@@ -854,15 +736,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -854,15 +736,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real, p_c_real,
p_c_imag, p_c_imag,
p_workspace, p_workspace,
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, GridwiseGemm::CalculateMPadded(M),
b_element_op, GridwiseGemm::CalculateNPadded(N),
c_element_op}; GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -875,15 +759,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -875,15 +759,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void* p_c_real, void* p_c_real,
void* p_c_imag, void* p_c_imag,
void* p_workspace, void* p_workspace,
index_t MRaw, index_t M,
index_t NRaw, index_t N,
index_t KRaw, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op, CElementwiseOperation,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
...@@ -893,15 +777,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -893,15 +777,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast<CDataType*>(p_c_real), static_cast<CDataType*>(p_c_real),
static_cast<CDataType*>(p_c_imag), static_cast<CDataType*>(p_c_imag),
static_cast<CDataType*>(p_workspace), static_cast<CDataType*>(p_workspace),
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, GridwiseGemm::CalculateMPadded(M),
b_element_op, GridwiseGemm::CalculateNPadded(N),
c_element_op); GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment