Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0fb89c4a
Commit
0fb89c4a
authored
Nov 18, 2022
by
Rosty Geyyer
Browse files
Update MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
parent
6f0b21a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
979 additions
and
601 deletions
+979
-601
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
...u/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
+979
-601
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp
View file @
0fb89c4a
...
...
@@ -142,178 +142,317 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
auto
K0
=
K
/
K1
;
//
const auto K0 = K / K1;
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
// const auto wei_n_x_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, X, C));
const
index_t
GemmKTotal
=
N
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
// make_tuple(make_pass_through_transform(N * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: input tensor
const
auto
in_n_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
// const auto in_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// const auto wei_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// wei_n_x_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// wei_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
*
C
));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
we
i_gemmk0_gemmn_gemmk1_grid_desc
,
i
n
_gemmm_gemmn_grid_desc
);
i
n
_gemmk0_gemmn_gemmk1_grid_desc
,
we
i_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
out_n_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wo
,
K
));
const
auto
wei_k_x_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
,
C
));
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// GemmK is different for each GEMM
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
));
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
// A: output tensor
const
auto
out_n_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
out_n_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
out_n_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_xdotslice_wtildeslice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
out_n_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
>
{}));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// B weight tensor
const
auto
wei_k_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
wei_k0_k1_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<>
{},
Sequence
<
3
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
0
>
{},
Sequence
<
3
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
//
C
: input tensor
//
B
: input tensor
const
auto
in_n_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x
tilde_wtilde
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_x
_wo
_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
WTildeSlice
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
// out_n_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
// wei_k_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
}
// function end
...
...
@@ -357,185 +496,126 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
auto
K0
=
K
/
K1
;
// const auto K0 = K / K1;
// const auto out_n_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
// const auto wei_k_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
// const auto in_n_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
const
auto
wei_k_y_x_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
,
X
,
C
));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Hi
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: input tensor
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
we
i_gemmk0_gemmn_gemmk1_grid_desc
,
i
n
_gemmm_gemmn_grid_desc
);
i
n
_gemmk0_gemmn_gemmk1_grid_desc
,
we
i_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{}));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// B weight tensor
const
auto
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilde
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// C: input tensor
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
...
...
@@ -545,48 +625,222 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y
tilde_htilde_xtilde_wtilde
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_y
_ho_x_wo
_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytilde
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5, 6>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_y_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<4>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
}
// function end
...
...
@@ -639,241 +893,379 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
auto
K0
=
K
/
K1
;
// const auto K0 = K / K1;
// const auto out_n_do_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
// const auto wei_k_z_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
// const auto in_n_di_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const
auto
out_n_do_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
K
));
const
auto
wei_k_z_y_x_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
,
Y
,
X
,
C
));
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Do
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
// const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_z_do_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Do, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<0, 2, 4, 6>{},
// Sequence<7>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: input tensor
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
I1
,
Do
),
make_tuple
(
I1
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
we
i_gemmk0_gemmn_gemmk1_grid_desc
,
i
n
_gemmm_gemmn_grid_desc
);
i
n
_gemmk0_gemmn_gemmk1_grid_desc
,
we
i_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
ZDot
=
math
::
integer_divide_ceil
(
Z
,
ZTilde
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
DTilde
=
Do
+
math
::
integer_divide_ceil
(
ConvDilationD
*
(
Z
-
I1
),
ConvStrideD
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadD
-
ConvDilationD
*
(
ZTilde
-
I1
)),
ConvStrideD
);
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IDTildeSliceEnd
=
math
::
min
(
DTilde
,
math
::
integer_divide_ceil
(
InLeftPadD
+
Di
-
I1
,
ConvStrideD
)
+
I1
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// GemmK is different for each GEMM
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto ZTilde = ConvStrideD / GcdStrideDilationD;
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto DTilde =
// Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IDTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IDTildeSliceEnd = math::min(
// DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_do_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Do, I0, I0),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
// transform_tensor_descriptor(
// out_n_dop_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZDot, DTilde),
// make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7, 8>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Di, InLeftPadD, InRightPadD),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
// transform_tensor_descriptor(
// in_n_dip_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZTilde, DTilde),
// make_tuple(ConvDilationD, ConvStrideD)),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
// transform_tensor_descriptor(
// in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ztilde),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<>{},
// Sequence<3>{},
// Sequence<4>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
// transform_tensor_descriptor(
// wei_k_z_y_x_c_grid_desc,
// make_tuple(
// make_pass_through_transform(K),
// make_embed_transform(make_tuple(ZDot, ZTilde),
// make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ztilde),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<5>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
// A: output tensor
const
auto
out_n_dop_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_do_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Do
,
I0
,
I0
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
out_n_dop_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
ZDot
,
DTilde
),
make_tuple
(
-
ConvDilationD
/
GcdStrideDilationD
,
I1
)),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
,
8
>
{}));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
8
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// B weight tensor
const
auto
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_z_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
ZDot
,
ZTilde
),
make_tuple
(
ConvStrideD
/
GcdStrideDilationD
,
I1
)),
make_embed_transform
(
make_tuple
(
YDot
,
YTilde
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ztilde
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
5
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
3
,
4
,
0
>
{},
Sequence
<
5
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// C: input tensor
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: input tensor
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
...
...
@@ -886,64 +1278,50 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
ZTilde
,
DTilde
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ztilde
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_freeze_transform
(
i_ytilde
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
,
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
we
i_gemmk0_gemmn_gemmk1_grid_desc
,
i
n
_gemmm_gemmn_grid_desc
);
i
n
_gemmk0_gemmn_gemmk1_grid_desc
,
we
i_gemmm_gemmn_grid_desc
);
}
}
// function end
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment