"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "c1159e3c46c194e367e44224180c1f1e922e3fc4"
Commit b9f23971 authored by Rosty Geyyer
Browse files

Update the argument

parent cb4b511a
......@@ -719,12 +719,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
OutElementwiseOperation out_element_op,
ck::index_t split_k)
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
......@@ -739,10 +740,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
k_batch_ = 1;
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N,
......@@ -757,12 +757,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads,
k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
a_grid_desc_kbatch_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
b_grid_desc_kbatch_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
......@@ -771,12 +771,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -809,16 +809,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_{ "
......@@ -832,16 +834,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
"wrong! GridwiseGemm GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
......@@ -867,13 +869,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.a_grid_desc_kbatch_k0_m0_m1_k1_,
arg.b_grid_desc_kbatch_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
......@@ -985,8 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_);
}
......@@ -1030,7 +1032,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
out_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -1071,7 +1074,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
out_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment