"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5afbcce176cd4e8ec08f43ee9fae2d6562edf54c"
Unverified Commit 9eae73df authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Simplify kernel argument of device operator Device(Batched)GemmXdl<> (#723)



* Remove M/N/KPad local variables

* Use M/N/KPad to name padded lengths

* Replace duplicated local variable by parameters

* Rename variables M/N/KRaw to M/N/K

* Move AK0/BK0 compute logic into GridwiseGemm

* Use macro to shorten code

* Move CalculateGridSize() logic into GridwiseGemm

* Add comment to credit the implementation source

* Reuse the existing implementation

* Remove no-longer used data members

* Remove elementwise-op objects from interfaces

* Reserve kernel arg as whole object in interfaces

* Remove redundant data member

* Make 3rd type parameter optional

* Remove unnesscary type parameters

* Remove no-longer used descriptor-creation methods

* Move kernel arg type definition into GridwiseGemm

* Add macro to switch between code sections

* Move argument field computing logic into device op side

* Make utility method 'static'

* Declare special methods

* Unify MakeArgument() usage

* Adapt the new GridwiseGemm interface

* Push-down class 'GridwiseGemm::Argument' fields

* Remove no-longer used methods

* Add unused parameters

* Force copying parameters in 'Embed' ctor

* Remove no-longer used descriptors

* Fallback change on BaseArgument

* Remove macro 'INTEGER_DIVIDE_CEIL'

* Make variable naming more consistent

* Make sure methods are only invoked on right place

* Remove tailing underscore in public attribute name

* Remove necessary methods

* Hide computing logic of derived attributes

* Make new 'Embed' ctor only available for device code

* Make sure 'Embed' type args are not references

* Move check for karg.K into CheckValidity()

* Remove more integer division logic form device code

* Undo changes on Embed

* Separate 'Problem' concept out from 'Argument'

* Add overloaded version of __builtin_amdgcn_readfirstlane()

* Remove 'static' specifiers

* Remove more 'static' specifier

* Replace unsigne char by std::byte

* Add 'const' specifier to never changing variable

* Add 'inline' specifier to funcion definition

* Share same name for kernel interfaces

* Fix wrong boundar calculation logic

* Leave the third template arg for compatibility

* Remove unnecessary parameters

* Fix wrong error message (for type name)

* Create descriptor on device side

* Fix wrong debug message

* Remove no-longer used data members

* Rename type trait

* Remove std:: qualifier from standard types

* Replace 'size_t' by 'unsigned'

* Use type alias to hint usage

* Replace static_for<> by ordinary 'for' loop

* Reject unsupported argument

* Rename readfirstlane() to amd_wave_read_first_lane()

* Rename file readfirstlance.hpp as amd_wave_read_first_lane.hpp

* Update function calls

* Reorder statements

* Re-format files

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent b94fd0b2
...@@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification, ...@@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification,
in_device_buf.SetZero(); in_device_buf.SetZero();
// do GEMM // do GEMM
auto conv = DeviceConvNdBwdDataInstance{}; auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), auto argument =
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
conv_param.N_, static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.K_, conv_param.N_,
conv_param.C_, conv_param.K_,
conv_param.input_spatial_lengths_, conv_param.C_,
conv_param.filter_spatial_lengths_, conv_param.input_spatial_lengths_,
conv_param.GetOutputSpatialLengths(), conv_param.filter_spatial_lengths_,
conv_param.conv_filter_strides_, conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_dilations_, conv_param.conv_filter_strides_,
conv_param.input_left_pads_, conv_param.conv_filter_dilations_,
conv_param.input_right_pads_, conv_param.input_left_pads_,
in_element_op, conv_param.input_right_pads_,
wei_element_op, in_element_op,
out_element_op); wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
if(!conv.IsSupportedArgument(argument.get()))
{ {
std::cout << "Not support,please check parameters or device"; std::cout << "Not support,please check parameters or device";
return 0; return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops(); std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>(); std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
......
...@@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid}, p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid}, p_c_grid_{p_in_grid},
M01_{M01},
N01_{N01},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
} }
...@@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
index_t N01_;
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
...@@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl; << std::endl;
std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< " ) " << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
} }
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( const auto [gdx, gdy, gdz] =
arg.c_grid_desc_m_n_container_[i]); GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t< DeviceOp::CGridDesc_M_N,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, true>;
OutElementwiseOperation,
WeiElementwiseOperation, ave_time += launch_and_time_kernel(stream_config,
InElementwiseOperation, kernel,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, dim3(gdx, gdy, gdz),
true>; dim3(BlockSize),
0,
ave_time += launch_and_time_kernel( arg.p_a_grid_,
stream_config, arg.p_b_grid_,
kernel, arg.p_c_grid_,
dim3(grid_size), arg.a_grid_desc_k0_m_k1_container_[i],
dim3(BlockSize), arg.b_grid_desc_k0_n_k1_container_[i],
0, arg.c_grid_desc_m_n_container_[i]);
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t< DeviceOp::CGridDesc_M_N,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, false>;
OutElementwiseOperation,
WeiElementwiseOperation, ave_time += launch_and_time_kernel(stream_config,
InElementwiseOperation, kernel,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, dim3(gdx, gdy, gdz),
false>; dim3(BlockSize),
0,
ave_time += launch_and_time_kernel( arg.p_a_grid_,
stream_config, arg.p_b_grid_,
kernel, arg.p_c_grid_,
dim3(grid_size), arg.a_grid_desc_k0_m_k1_container_[i],
dim3(BlockSize), arg.b_grid_desc_k0_n_k1_container_[i],
0, arg.c_grid_desc_m_n_container_[i]);
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
} }
return ave_time; return ave_time;
...@@ -716,8 +648,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -716,8 +648,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
return false; return false;
} }
...@@ -742,10 +673,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -742,10 +673,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -759,12 +687,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -759,12 +687,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads};
1,
1,
in_element_op,
wei_element_op,
out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -783,9 +706,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -783,9 +706,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation,
OutElementwiseOperation out_element_op) override OutElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid), static_cast<const WeiDataType*>(p_wei_grid),
...@@ -799,12 +722,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -799,12 +722,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads);
1,
1,
in_element_op,
wei_element_op,
out_element_op);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -329,9 +329,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -329,9 +329,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -378,25 +375,13 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -378,25 +375,13 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid}, : p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid}, p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid}, p_c_grid_{p_out_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -420,17 +405,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -420,17 +405,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0]; a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1]; b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
}
} }
// private: // private:
...@@ -440,14 +414,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -440,14 +414,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
...@@ -479,17 +445,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -479,17 +445,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
...@@ -498,22 +461,18 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -498,22 +461,18 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, DeviceOp::CGridDesc_M_N,
InElementwiseOperation, true>;
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -521,30 +480,22 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -521,30 +480,22 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m_n_);
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, DeviceOp::CGridDesc_M_N,
InElementwiseOperation, false>;
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -552,11 +503,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -552,11 +503,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m_n_);
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -616,10 +563,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -616,10 +563,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -639,10 +584,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -639,10 +584,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -656,12 +598,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -656,12 +598,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads};
1,
1,
in_element_op,
wei_element_op,
out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -680,9 +617,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -680,9 +617,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation,
OutElementwiseOperation out_element_op) override OutElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid), static_cast<const WeiDataType*>(p_wei_grid),
...@@ -696,12 +633,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -696,12 +633,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads);
1,
1,
in_element_op,
wei_element_op,
out_element_op);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -980,9 +980,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -980,9 +980,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -1029,20 +1026,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1029,20 +1026,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid}, p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid}, p_c_grid_{p_in_grid},
M01_{M01},
N01_{N01},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -1092,17 +1079,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1092,17 +1079,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -1150,18 +1126,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1150,18 +1126,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
} }
...@@ -1218,19 +1182,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1218,19 +1182,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]); c_grid_desc_m_n_container_.push_back(descs[I2]);
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
descs[I2]));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
} }
} }
} }
...@@ -1242,11 +1193,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1242,11 +1193,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
index_t N01_;
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_; WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_; InElementwiseOperation c_element_op_;
...@@ -1276,123 +1222,84 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1276,123 +1222,84 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{ {
#if DEBUG_LOG #if DEBUG_LOG
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" std::cout << "arg.a_grid_desc_k0_m_k1{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl; << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{" std::cout << "arg.b_grid_desc_k0_n_k1{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl; << std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ " std::cout << "arg.c_grid_desc_m_n{"
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl; << std::endl;
std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7)
<< " ) " << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( const auto [gdx, gdy, gdz] =
arg.c_grid_desc_m_n_container_[i]); GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t< DeviceOp::CGridDesc_M_N,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, true>;
OutElementwiseOperation,
WeiElementwiseOperation, ave_time += launch_and_time_kernel(stream_config,
InElementwiseOperation, kernel,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, dim3(gdx, gdy, gdz),
true>; dim3(BlockSize),
0,
ave_time += launch_and_time_kernel( arg.p_a_grid_,
stream_config, arg.p_b_grid_,
kernel, arg.p_c_grid_,
dim3(grid_size), arg.a_grid_desc_k0_m_k1_container_[i],
dim3(BlockSize), arg.b_grid_desc_k0_n_k1_container_[i],
0, arg.c_grid_desc_m_n_container_[i]);
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel =
GridwiseGemm, kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, DeviceOp::AGridDesc_K0_M_K1,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, DeviceOp::BGridDesc_K0_N_K1,
remove_reference_t< DeviceOp::CGridDesc_M_N,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, false>;
OutElementwiseOperation,
WeiElementwiseOperation, ave_time += launch_and_time_kernel(stream_config,
InElementwiseOperation, kernel,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, dim3(gdx, gdy, gdz),
false>; dim3(BlockSize),
0,
ave_time += launch_and_time_kernel( arg.p_a_grid_,
stream_config, arg.p_b_grid_,
kernel, arg.p_c_grid_,
dim3(grid_size), arg.a_grid_desc_k0_m_k1_container_[i],
dim3(BlockSize), arg.b_grid_desc_k0_n_k1_container_[i],
0, arg.c_grid_desc_m_n_container_[i]);
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
} }
} }
return ave_time; return ave_time;
...@@ -1446,8 +1353,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1446,8 +1353,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i], arg.c_grid_desc_m_n_container_[i]))
arg.block_2_ctile_map_container_[i]))
{ {
return false; return false;
} }
...@@ -1472,10 +1378,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1472,10 +1378,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads)
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -1489,12 +1392,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1489,12 +1392,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads};
1,
1,
in_element_op,
wei_element_op,
out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -1513,9 +1411,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1513,9 +1411,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation,
OutElementwiseOperation out_element_op) override OutElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid), static_cast<const WeiDataType*>(p_wei_grid),
...@@ -1529,12 +1427,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1529,12 +1427,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads);
1,
1,
in_element_op,
wei_element_op,
out_element_op);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -75,132 +75,20 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -75,132 +75,20 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, ALayout,
BGridDesc_K0_N_K1, BLayout,
CGridDesc_M_N, CLayout,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
...@@ -232,173 +120,41 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -232,173 +120,41 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
// Argument using Argument = typename GridwiseGemm::Argument;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
kraw_{K}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_;
};
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmXdl::Argument; float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(stream_config.log_level_ > 0)
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) karg.Print();
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(karg))
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
} }
const index_t grid_size = const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0; float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ave_time = launch_and_time_kernel(
CDataType, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3< const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ave_time = launch_and_time_kernel(
CDataType, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -418,7 +174,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -418,7 +174,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& karg)
{ {
if(ck::get_device_name() == "gfx908") if(ck::get_device_name() == "gfx908")
{ {
...@@ -441,15 +197,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -441,15 +197,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
if(arg.kraw_ % K1 != 0) if(karg.K % K1 != 0)
{ {
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(karg);
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
} }
// polymorphic // polymorphic
...@@ -467,24 +220,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -467,24 +220,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op) CElementwiseOperation)
{ {
return Argument{p_a, return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -499,9 +239,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -499,9 +239,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op) override CElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -511,12 +251,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -511,12 +251,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC);
1,
1,
a_element_op,
b_element_op,
c_element_op);
} }
// polymorphic // polymorphic
......
...@@ -134,6 +134,14 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -134,6 +134,14 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{ {
} }
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
...@@ -142,6 +150,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -142,6 +150,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
return M0 * N0; return M0 * N0;
} }
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{ {
...@@ -222,30 +242,12 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -222,30 +242,12 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
index_t M01_; index_t M01_;
}; };
// keep the redundant type argument for backward compatibility
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N> template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{ {
using Parent = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>; using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
BlockToCTileMap_M00_N0_M01Adapt;
using Parent::I0;
using Parent::I1;
using Parent::Parent;
using Parent::operator=;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: Parent(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return Parent::CalculateGridSize(c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
}; };
// 2D slices of column-vectors in 3D space // 2D slices of column-vectors in 3D space
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment