Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
37116c98
Commit
37116c98
authored
Nov 30, 2022
by
Rosty Geyyer
Browse files
Refactor argument preparation
parent
0fb89c4a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
123 additions
and
1024 deletions
+123
-1024
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
...u/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
+123
-1024
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
View file @
37116c98
...
...
@@ -128,12 +128,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
index_t
i_xtilde
=
tildes
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
0
];
...
...
@@ -142,30 +140,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
// const auto K0 = K / K1;
// const auto wei_n_x_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, X, C));
const
index_t
GemmKTotal
=
N
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
// make_tuple(make_pass_through_transform(N * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
));
...
...
@@ -184,32 +172,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
// const auto in_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wi
,
C
));
...
...
@@ -228,37 +190,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// const auto wei_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// wei_n_x_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// wei_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
*
C
));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
...
...
@@ -331,128 +265,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
// out_n_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
// wei_k_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
}
// function end
...
...
@@ -468,13 +280,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
index_t
i_ytilde
=
tildes
[
0
];
index_t
i_xtilde
=
tildes
[
1
];
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
...
...
@@ -496,35 +305,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
// const auto K0 = K / K1;
// const auto out_n_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
// const auto wei_k_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
// const auto in_n_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
...
...
@@ -543,24 +337,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Hi
*
Wi
,
C
));
...
...
@@ -579,13 +355,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
...
...
@@ -663,184 +432,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5, 6>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_y_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<4>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
}
// function end
...
...
@@ -857,14 +448,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
const
index_t
i_ztilde
=
tildes
[
0
];
const
index_t
i_ytilde
=
tildes
[
1
];
const
index_t
i_xtilde
=
tildes
[
2
];
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
...
...
@@ -893,35 +480,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
// const auto K0 = K / K1;
// const auto out_n_do_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
// const auto wei_k_z_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
// const auto in_n_di_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
...
...
@@ -940,35 +512,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_z_do_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Do, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<0, 2, 4, 6>{},
// Sequence<7>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
));
...
...
@@ -987,13 +530,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
));
...
...
@@ -1003,248 +539,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
else
{
// const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto ZTilde = ConvStrideD / GcdStrideDilationD;
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto DTilde =
// Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IDTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IDTildeSliceEnd = math::min(
// DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_do_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Do, I0, I0),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
// transform_tensor_descriptor(
// out_n_dop_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZDot, DTilde),
// make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7, 8>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Di, InLeftPadD, InRightPadD),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
// transform_tensor_descriptor(
// in_n_dip_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZTilde, DTilde),
// make_tuple(ConvDilationD, ConvStrideD)),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
// transform_tensor_descriptor(
// in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ztilde),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<>{},
// Sequence<3>{},
// Sequence<4>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
// transform_tensor_descriptor(
// wei_k_z_y_x_c_grid_desc,
// make_tuple(
// make_pass_through_transform(K),
// make_embed_transform(make_tuple(ZDot, ZTilde),
// make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ztilde),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<5>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
const
auto
in_n_di_hi_wi_c_grid_desc
=
...
...
@@ -1330,14 +624,14 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
}
);
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
1
);
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
}
);
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
1
);
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
...
...
@@ -1353,7 +647,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
}
);
1
);
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
...
...
@@ -1429,6 +723,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
...
...
@@ -1443,208 +740,51 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
CreateABCDesc
<
NDimSpatial
>
();
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideW
=
conv_filter_strides_
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
0
];
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
X
=
filter_spatial_lengths_
[
0
];
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
XDotSlice
<=
0
)
{
continue
;
}
k_batch_
=
1
;
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideH
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
1
];
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
Y
=
filter_spatial_lengths_
[
0
];
const
index_t
X
=
filter_spatial_lengths_
[
1
];
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
k_batch_
);
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideD
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides_
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
2
];
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
Z
=
filter_spatial_lengths_
[
0
];
const
index_t
Y
=
filter_spatial_lengths_
[
1
];
const
index_t
X
=
filter_spatial_lengths_
[
2
];
for
(
index_t
i_ztilde
=
0
;
i_ztilde
<
ZTilde
;
++
i_ztilde
)
{
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
ZDotSlice
*
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
a_grid_desc_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
}
a_grid_desc_k0_m0_m1_k1_
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1_
);
b_grid_desc_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
b_grid_desc_k0_n_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
AGridDesc_K0_M
0_M1
_K1
>
a_grid_desc_k0_m
0_m1_k1_container
_
;
std
::
vector
<
BGridDesc_K0_N
0_N1
_K1
>
b_grid_desc_k0_n
0_n1_k1_container
_
;
std
::
vector
<
CGridDesc_M0_M10_M11_N0_N10_N11
>
c_grid_desc_m0_m10_m11_n0_n10_n11_container
_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m
_k1
_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n
_k1
_
;
CGridDesc_M_N
c_grid_desc_m_n
_
;
std
::
vector
<
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1_
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1_
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11_
;
DefaultBlock2CTileMap
block_2_ctile_map_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
...
...
@@ -1657,6 +797,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
index_t
>
input_left_pads_
;
std
::
vector
<
ck
::
index_t
>
input_right_pads_
;
index_t
k_batch_
;
};
// Invoker
...
...
@@ -1664,53 +805,41 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{}
)
void
ShowInfo
(
const
Argument
&
arg
)
{
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_
container_
{"
<<
arg
.
b_grid_desc_k0_n_k1_
container_
[
i
]
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
container_
[
i
]
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
container_
[
i
]
.
GetLength
(
I2
)
<<
"}"
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_
container_
{ "
<<
arg
.
c_grid_desc_m_n_
container_
[
i
]
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
container_
[
i
]
.
GetLength
(
I1
)
<<
"}"
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I2
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I3
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I4
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
]))
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_container_
[
i
].
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
container_
[
i
]
);
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
,
auto
has_double_tail_k_block_loop
)
{
...
...
@@ -1728,8 +857,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
has_main_loop
,
has_double_loop
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
...
...
@@ -1737,39 +865,37 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
container_
[
i
]
,
arg
.
b_grid_desc_k0_n0_n1_k1_
container_
[
i
]
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
container_
[
i
]
,
arg
.
block_2_ctile_map_
container_
[
i
]
);
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
block_2_ctile_map_
);
};
const
auto
K0
=
arg
.
a_grid_desc_k0_m0_m1_k1_
container_
[
i
]
.
GetLength
(
I0
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m0_m1_k1_
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
...
...
@@ -1806,26 +932,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
}
// // vector load A/B matrix from global memory
// if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
// arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
// arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
// {
// return false;
// }
// // vector store C matrix into global memory
// if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
// {
// return false;
// }
// // Gridwise GEMM size
// return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
// arg.b_grid_desc_kbatch_k0_n_k1_,
// arg.c_grid_desc_m_n_,
// arg.block_2_ctile_map_);
// matrix A
{
auto
srcVectorLengths
=
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
{};
...
...
@@ -1877,16 +983,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
]))
{
return
false
;
}
}
return
true
;
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment