Commit 7bae1691 authored by Po-Yen, Chen

Adapt to the new GridwiseGemm interface

parent ceebf306
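
This change moves DeviceCGemm_4Gemm_Xdl_CShuffle onto the new GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 interface: the layouts and GemmSpec become template parameters of GridwiseGemm, DeviceOp::Argument now derives from GridwiseGemm::Argument (carrying M/N/K, the strides, the padded sizes, and AK0/BK0), and the Invoker launches the simplified kernel entry point with a single argument struct instead of passing each grid descriptor and element-wise op separately. A condensed sketch of the adapted launch flow, assuming the surrounding Invoker::Run scope from the diff below (only one of the four GEMM launches is shown; the real/imag operand wiring and the add/subtract epilogues are elided):

    Argument karg = arg;                     // DeviceOp::Argument now inherits GridwiseGemm::Argument
    if(!GridwiseGemm::CheckValidity(karg))   // validity check now takes the argument struct itself
    {
        throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
    }

    index_t gdx, gdy, gdz;
    std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);

    // the boolean picks the variant with/without the main K-block loop,
    // selected via GridwiseGemm::CalculateHasMainKBlockLoop(K) in the diff
    const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;

    karg.p_a_grid = karg.p_a_grid_real_;     // choose the operand pair for this partial GEMM
    karg.p_b_grid = karg.p_b_grid_real_;
    karg.p_c_grid = karg.p_aux_grid_;
    ave_time += launch_and_time_kernel(
        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
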
@@ -389,6 +389,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
@@ -396,10 +399,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -434,8 +435,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
LoopSched>;
// Argument
struct Argument : public BaseArgument
struct Argument : public GridwiseGemm::Argument
{
using Parent = typename GridwiseGemm::Argument;
Argument(const ADataType* p_a_grid_real,
const ADataType* p_a_grid_imag,
const BDataType* p_b_grid_real,
@@ -443,55 +446,53 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real,
CDataType* p_c_grid_imag,
CDataType* p_workspace,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_real_{p_a_grid_real},
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t AK0_,
index_t BK0_)
: Parent(nullptr,
nullptr,
nullptr,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
AK0_,
BK0_),
p_a_grid_real_{p_a_grid_real},
p_a_grid_imag_{p_a_grid_imag},
p_b_grid_real_{p_b_grid_real},
p_b_grid_imag_{p_b_grid_imag},
p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_workspace},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
p_aux_grid_{p_workspace}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
}
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize();
}
// private:
@@ -503,38 +504,32 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_imag_;
CDataType* p_aux_grid_;
CDataType* p_aux_2_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M c_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
// void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
Argument karg = arg;
if(stream_config.log_level_ > 0)
{
// Print(karg);
}
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
float ave_time = 0;
@@ -578,224 +573,114 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
stream_config,
subtract_kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_real_),
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_real_),
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
stream_config,
add_kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_imag_),
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_imag_),
Add{});
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
stream_config,
subtract_kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_real_),
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_real_),
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
stream_config,
add_kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
make_tuple(arg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
const_cast<const CDataType*>(arg.p_aux_2_grid_)),
make_tuple(arg.p_c_grid_imag_),
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_imag_),
Add{});
}
@@ -816,12 +701,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return true;
}
static bool IsSupportedArgument(const Argument& arg)
static bool IsSupportedArgument(const Argument& karg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
return GridwiseGemm::CheckValidity(karg);
}
// polymorphic
@@ -837,15 +719,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_real,
CDataType* p_c_imag,
CDataType* p_workspace,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a_real,
p_a_imag,
@@ -854,15 +736,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real,
p_c_imag,
p_workspace,
MRaw,
NRaw,
KRaw,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op};
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -875,15 +759,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void* p_c_real,
void* p_c_imag,
void* p_workspace,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
@@ -893,15 +777,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast<CDataType*>(p_c_real),
static_cast<CDataType*>(p_c_imag),
static_cast<CDataType*>(p_workspace),
MRaw,
NRaw,
KRaw,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
}
// polymorphic
......
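
For reference, the four kernel launches in each branch implement the straightforward complex-GEMM decomposition; the subtract/add epilogue kernels then combine the two auxiliary buffers exactly as the inline comments state ("c_real = aux - aux_2", "c_imag = aux + aux_2"):

    // aux = a_real * b_real,  aux_2 = a_imag * b_imag  ->  c_real = aux - aux_2
    // aux = a_real * b_imag,  aux_2 = a_imag * b_real  ->  c_imag = aux + aux_2
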