Commit b9f23971 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Update the argument

parent cb4b511a
...@@ -719,12 +719,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -719,12 +719,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::array<ck::index_t, NDimSpatial> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op,
ck::index_t split_k)
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid}, p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid}, p_c_grid_{p_wei_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
...@@ -739,10 +740,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -739,10 +740,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads},
k_batch_{split_k}
{ {
k_batch_ = 1;
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, N,
...@@ -757,12 +757,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -757,12 +757,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads, input_right_pads,
k_batch_); k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0]; a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); a_grid_desc_kbatch_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); b_grid_desc_kbatch_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
} }
...@@ -771,12 +771,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -771,12 +771,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -809,16 +809,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -809,16 +809,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
void ShowInfo(const Argument& arg) void ShowInfo(const Argument& arg)
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}"
<< std::endl; << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}"
<< std::endl; << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " std::cout << "arg.c_grid_desc_m_n_{ "
...@@ -832,16 +834,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -832,16 +834,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
ShowInfo(arg); ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_)) arg.c_grid_desc_m_n_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting"); "wrong! GridwiseGemm GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
auto launch_kernel = [&](auto has_main_k_block_loop, auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) { auto has_double_tail_k_block_loop) {
...@@ -867,13 +869,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -867,13 +869,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
}; };
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
...@@ -985,8 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -985,8 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_); arg.c_grid_desc_m_n_);
} }
...@@ -1030,7 +1032,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -1030,7 +1032,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op,
split_k};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -1071,7 +1074,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl ...@@ -1071,7 +1074,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op,
split_k);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment