Commit e67b958b authored by rocking's avatar rocking
Browse files

rename channel to c from k

parent 529f2507
......@@ -37,10 +37,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto
Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_k_wos_lengths,
const std::vector<ck::index_t>& din_n_k_wos_length,
const std::vector<ck::index_t>& dout_n_k_wos_strides,
const std::vector<ck::index_t>& din_n_k_wos_strides,
Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
const std::vector<ck::index_t>& din_n_c_wos_length,
const std::vector<ck::index_t>& dout_n_c_wos_strides,
const std::vector<ck::index_t>& din_n_c_wos_strides,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
......@@ -52,16 +52,16 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
index_t i_ytilde = tildes[1];
index_t i_xtilde = tildes[2];
const index_t N = dout_n_k_wos_lengths[0];
const index_t K = dout_n_k_wos_lengths[1];
const index_t N = dout_n_c_wos_lengths[0];
const index_t C = dout_n_c_wos_lengths[1];
const index_t Di = din_n_k_wos_length[2];
const index_t Hi = din_n_k_wos_length[3];
const index_t Wi = din_n_k_wos_length[4];
const index_t Di = din_n_c_wos_length[2];
const index_t Hi = din_n_c_wos_length[3];
const index_t Wi = din_n_c_wos_length[4];
const index_t Do = dout_n_k_wos_lengths[2];
const index_t Ho = dout_n_k_wos_lengths[3];
const index_t Wo = dout_n_k_wos_lengths[4];
const index_t Do = dout_n_c_wos_lengths[2];
const index_t Ho = dout_n_c_wos_lengths[3];
const index_t Wo = dout_n_c_wos_lengths[4];
const index_t Z = window_lengths[0];
const index_t Y = window_lengths[1];
......@@ -83,13 +83,13 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const index_t ConvDilationH = window_dilations[1];
const index_t ConvDilationW = window_dilations[2];
const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, K),
make_tuple(dout_n_k_wos_strides[0],
dout_n_k_wos_strides[2],
dout_n_k_wos_strides[3],
dout_n_k_wos_strides[4],
dout_n_k_wos_strides[1]));
const auto out_n_do_ho_wo_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C),
make_tuple(dout_n_c_wos_strides[0],
dout_n_c_wos_strides[2],
dout_n_c_wos_strides[3],
dout_n_c_wos_strides[4],
dout_n_c_wos_strides[1]));
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
......@@ -132,19 +132,19 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Out[ReduceM, ReduceK]
const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_k_grid_desc,
const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_k_grid_desc,
out_n_dop_hop_wop_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
......@@ -153,7 +153,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -163,9 +163,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
......@@ -173,7 +173,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -192,14 +192,14 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence<7>{}));
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * K;
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
......@@ -212,27 +212,27 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple(Sequence<0>{}, Sequence<1>{}));
// In[ReduceM]
const auto in_n_di_hi_wi_k_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, K),
make_tuple(din_n_k_wos_strides[0],
din_n_k_wos_strides[2],
din_n_k_wos_strides[3],
din_n_k_wos_strides[4],
din_n_k_wos_strides[1]));
const auto in_n_dip_hip_wip_k_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_k_grid_desc,
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(din_n_c_wos_strides[0],
din_n_c_wos_strides[2],
din_n_c_wos_strides[3],
din_n_c_wos_strides[4],
din_n_c_wos_strides[1]));
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_k_grid_desc =
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_k_grid_desc,
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(XTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
......@@ -240,7 +240,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -249,9 +249,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_k_grid_desc =
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_k_grid_desc,
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
......@@ -259,7 +259,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -278,9 +278,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence<4>{}));
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_k_grid_desc,
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, K))),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
......@@ -297,10 +297,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
Argument(const DOutDataType* p_dout,
DInDataType* p_din,
std::vector<ck::index_t> dout_n_k_wos_lengths,
std::vector<ck::index_t> din_n_k_wos_length,
std::vector<ck::index_t> dout_n_k_wos_strides,
std::vector<ck::index_t> din_n_k_wos_strides,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
......@@ -310,10 +310,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
ignore = p_dout;
ignore = p_din;
ignore = dout_n_k_wos_lengths;
ignore = dout_n_k_wos_strides;
ignore = din_n_k_wos_length;
ignore = din_n_k_wos_strides;
ignore = dout_n_c_wos_lengths;
ignore = dout_n_c_wos_strides;
ignore = din_n_c_wos_length;
ignore = din_n_c_wos_strides;
ignore = window_lengths;
ignore = window_strides;
ignore = window_dilations;
......@@ -383,10 +383,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_k_wos_lengths,
std::vector<ck::index_t> din_n_k_wos_length,
std::vector<ck::index_t> dout_n_k_wos_strides,
std::vector<ck::index_t> din_n_k_wos_strides,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
......@@ -395,10 +395,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_k_wos_lengths,
din_n_k_wos_length,
dout_n_k_wos_strides,
din_n_k_wos_strides,
dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment