Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel
Commits
c03045ce
You need to sign in or sign up before continuing.
Commit
c03045ce
authored
Aug 10, 2021
by
Chao Liu
Browse files
rename
parent
b2589957
Changes
54
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
586 additions
and
586 deletions
+586
-586
README.md
README.md
+5
-5
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
...ckward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
+36
-36
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
...ward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
+36
-36
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
...orm_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
+37
-37
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...orm_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+25
-25
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...m_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
...m_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...m_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
...orm_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
+23
-23
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+98
-94
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
...clude/tensor_description/multi_index_transform_helper.hpp
+20
-21
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+2
-2
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+44
-42
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
...l/include/tensor_description/tensor_descriptor_helper.hpp
+21
-22
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
...el/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
+24
-26
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
...el/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
+7
-7
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+12
-13
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+55
-55
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
...lude/tensor_operation/blockwise_tensor_slice_transfer.hpp
+28
-28
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
...e/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
+26
-27
No files found.
README.md
View file @
c03045ce
...
...
@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888}
...
...
@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176}
...
...
@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256}
...
...
@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
...
...
@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
...
...
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -23,9 +23,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const
auto
K0
=
K
/
K1
;
// weight tensor
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilda
),
...
...
@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
#if 1
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
...
...
@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_pass_through_transform
(
C
),
...
...
@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
// output tensor
// this add padding check
const
auto
out_n_hop_wop_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
...
...
@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilda
),
...
...
@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
...
...
@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence
<
5
,
6
>
{}));
#if 1
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#endif
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilda
,
HTilda
),
...
...
@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
IYTilda
),
...
...
@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
))),
...
...
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -26,9 +26,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
// A: output tensor
// this add padding check
const
auto
out_n_hop_wop_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
...
...
@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilda
),
...
...
@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
...
...
@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence
<
5
,
6
>
{}));
#if 1
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif
// B: weight tensor
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilda
),
...
...
@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
#if 1
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
...
...
@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_pass_through_transform
(
C
),
...
...
@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif
// C: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilda
,
HTilda
),
...
...
@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
IYTilda
),
...
...
@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
make_pass_through_transform
(
C
)),
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -18,9 +18,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
@@ -109,9 +109,9 @@ template <typename... Wei,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -147,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
assert
(
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -164,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
@@ -189,9 +189,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -229,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -18,9 +18,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
@@ -108,9 +108,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -148,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
C
)),
const
auto
in_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
C
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -20,9 +20,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -20,9 +20,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -23,9 +23,9 @@ template <typename... In,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// A: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: weight tensor
const
auto
wei_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -24,9 +24,9 @@ template <typename... Wei,
typename
C0Type
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const
auto
C1
=
C
/
C0
;
// weight tensor
const
auto
wei_gk0_gm0_gm1_gk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic
_naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
*
Y
*
X
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
0
>
{}));
const
auto
wei_gk0_gm0_gm1_gk1_grid_desc
=
transform_tensor_descriptor
(
make
_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
*
Y
*
X
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
0
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
)),
...
...
@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
,
7
>
{}));
const
auto
in_gk0_gn0_gn1_gk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gk0_gn0_gn1_gk1_grid_desc
=
transform_tensor_descriptor
(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C1
,
Y
,
X
)),
make_pass_through_transform
(
N0
),
...
...
@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
// output tensor
const
auto
out_n_k_howo_grid_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
));
const
auto
out_n0_n1_1_k_howo_grid_desc
=
transform_dynamic_tensor_descriptor
(
out_n_k_howo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_pass_through_transform
(
Ho
*
Wo
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_n0_n1_1_k_howo_grid_desc
=
transform_tensor_descriptor
(
out_n_k_howo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_pass_through_transform
(
Ho
*
Wo
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_gm0_gm1_gn0_gn1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gm0_gm1_gn0_gn1_grid_desc
=
transform_tensor_descriptor
(
out_n0_n1_1_k_howo_grid_desc
,
make_tuple
(
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
K
),
...
...
composable_kernel/include/tensor_description/
dynamic_
multi_index_transform.hpp
→
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
c03045ce
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_description/
dynamic_
multi_index_transform_helper.hpp
→
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_
DYNAMIC_
MULTI_INDEX_TRANSFORM_HELPER_HPP
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp"
#include "
dynamic_
multi_index_transform.hpp"
#include "multi_index_transform.hpp"
namespace
ck
{
template
<
typename
LowLength
>
__host__
__device__
constexpr
auto
make_pass_through_transform
(
const
LowLength
&
low_length
)
{
return
Dynamic
PassThrough
<
LowLength
>
{
low_length
};
return
PassThrough
<
LowLength
>
{
low_length
};
}
template
<
typename
LowLength
,
typename
LeftPad
,
typename
RightPad
,
bool
SkipIsValidCheck
=
false
>
...
...
@@ -19,26 +19,25 @@ make_pad_transform(const LowLength& low_length,
const
RightPad
&
right_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
DynamicPad
<
LowLength
,
LeftPad
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
,
right_pad
};
return
Pad
<
LowLength
,
LeftPad
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
,
right_pad
};
}
template
<
typename
LowLength
,
typename
LeftPad
,
bool
SkipIsValidCheck
=
false
>
template
<
typename
LowLength
,
typename
LeftPad
Length
,
bool
SkipIsValidCheck
=
false
>
__host__
__device__
constexpr
auto
make_left_pad_transform
(
const
LowLength
&
low_length
,
const
LeftPad
&
left_pad
,
const
LeftPad
Length
&
left_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
LeftPad
<
LowLength
,
LeftPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
};
return
LeftPad
<
LowLength
,
LeftPad
Length
,
SkipIsValidCheck
>
{
low_length
,
left_pad
};
}
template
<
typename
LowLength
,
typename
RightPad
,
bool
SkipIsValidCheck
>
template
<
typename
LowLength
,
typename
RightPad
Length
,
bool
SkipIsValidCheck
>
__host__
__device__
constexpr
auto
make_right_pad_transform
(
const
LowLength
&
low_length
,
const
RightPad
&
right_pad
,
const
RightPad
Length
&
right_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
RightPad
<
LowLength
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
right_pad
};
return
RightPad
<
LowLength
,
RightPad
Length
,
SkipIsValidCheck
>
{
low_length
,
right_pad
};
}
template
<
typename
UpLengths
,
...
...
@@ -47,19 +46,19 @@ template <typename UpLengths,
__host__
__device__
constexpr
auto
make_embed_transform
(
const
UpLengths
&
up_lengths
,
const
Coefficients
&
coefficients
)
{
return
Dynamic
Embed
<
UpLengths
,
Coefficients
>
{
up_lengths
,
coefficients
};
return
Embed
<
UpLengths
,
Coefficients
>
{
up_lengths
,
coefficients
};
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform
(
const
LowLengths
&
low_lengths
)
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return
Dynamic
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
#else
#if 1
return
Dynamic
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
#else
return
Dynamic
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
#endif
#endif
}
...
...
@@ -68,7 +67,7 @@ template <typename LowLengths>
__host__
__device__
constexpr
auto
make_merge_transform_v2_magic_division
(
const
LowLengths
&
low_lengths
)
{
return
Dynamic
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
=
false
>
...
...
@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
const
UpLengths
&
up_lengths
,
integral_constant
<
bool
,
Use24BitIntegerCalculation
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
UnMerge
<
UpLengths
,
Use24BitIntegerCalculation
>
{
up_lengths
};
return
UnMerge
<
UpLengths
,
Use24BitIntegerCalculation
>
{
up_lengths
};
}
template
<
typename
LowerIndex
>
__host__
__device__
constexpr
auto
make_freeze_transform
(
const
LowerIndex
&
low_idx
)
{
return
Dynamic
Freeze
<
LowerIndex
>
{
low_idx
};
return
Freeze
<
LowerIndex
>
{
low_idx
};
}
template
<
typename
LowLength
,
typename
SliceBegin
,
typename
SliceEnd
>
...
...
@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
const
SliceBegin
&
slice_begin
,
const
SliceEnd
&
slice_end
)
{
return
Dynamic
Slice
<
LowLength
,
SliceBegin
,
SliceEnd
>
{
low_length
,
slice_begin
,
slice_end
};
return
Slice
<
LowLength
,
SliceBegin
,
SliceEnd
>
{
low_length
,
slice_begin
,
slice_end
};
}
template
<
typename
VectorSize
,
typename
UpLength
>
__host__
__device__
constexpr
auto
make_vectorize_transform
(
const
VectorSize
&
vector_size
,
const
UpLength
&
up_length
)
{
return
Dynamic
Vectorize
<
VectorSize
,
UpLength
>
{
vector_size
,
up_length
};
return
Vectorize
<
VectorSize
,
UpLength
>
{
vector_size
,
up_length
};
}
}
// namespace ck
...
...
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
composable_kernel/include/tensor_description/
dynamic_
tensor_descriptor.hpp
→
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HPP
#define CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "
dynamic_
multi_index_transform.hpp"
#include "multi_index_transform.hpp"
namespace
ck
{
template
<
index_t
NDimHidden
,
typename
VisibleDimensionIds
>
struct
Dynamic
TensorCoordinate
;
struct
TensorCoordinate
;
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
Dynamic
TensorCoordinateIterator
;
struct
TensorCoordinateIterator
;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
...
...
@@ -21,7 +21,7 @@ template <typename Transforms,
typename
UpperDimensionIdss
,
typename
VisibleDimensionIds
,
typename
ElementSpaceSize
>
struct
Dynamic
TensorDescriptor
struct
TensorDescriptor
{
// TODO make these private
__host__
__device__
static
constexpr
index_t
GetNumOfTransform
()
{
return
Transforms
::
Size
();
}
...
...
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
using
VisibleIndex
=
MultiIndex
<
ndim_visible_
>
;
using
HiddenIndex
=
MultiIndex
<
ndim_hidden_
>
;
using
Coordinate
=
Dynamic
TensorCoordinate
<
ndim_hidden_
,
VisibleDimensionIds
>
;
using
Coordinate
=
TensorCoordinate
<
ndim_hidden_
,
VisibleDimensionIds
>
;
// may be index_t or Number<>
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorDescriptor
()
=
default
;
__host__
__device__
constexpr
TensorDescriptor
()
=
default
;
__host__
__device__
constexpr
Dynamic
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)},
element_space_size_
{
element_space_size
}
...
...
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
{
static_assert
(
Idx
::
Size
()
==
GetNumOfDimension
(),
"wrong! inconsistent # of dimension"
);
return
make_
dynamic_
tensor_coordinate
(
*
this
,
idx
).
GetOffset
();
return
make_tensor_coordinate
(
*
this
,
idx
).
GetOffset
();
}
// TODO make these private
...
...
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"
Dynamic
TensorDescriptor, "
);
printf
(
"TensorDescriptor, "
);
static_for
<
0
,
ntransform_
,
1
>
{}([
&
](
auto
i
)
{
printf
(
"transforms: "
);
transforms_
[
i
].
Print
();
...
...
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
};
template
<
index_t
NDimHidden
,
typename
VisibleDimensionIds
>
struct
Dynamic
TensorCoordinate
struct
TensorCoordinate
{
// TODO make these private
static
constexpr
index_t
ndim_visible_
=
VisibleDimensionIds
::
Size
();
...
...
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
using
VisibleIndex
=
MultiIndex
<
ndim_visible_
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorCoordinate
()
=
default
;
__host__
__device__
constexpr
TensorCoordinate
()
=
default
;
__host__
__device__
constexpr
Dynamic
TensorCoordinate
(
const
HiddenIndex
&
idx_hidden
)
__host__
__device__
constexpr
TensorCoordinate
(
const
HiddenIndex
&
idx_hidden
)
:
idx_hidden_
{
idx_hidden
}
{
}
...
...
@@ -252,16 +252,17 @@ struct DynamicTensorCoordinate
};
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
Dynamic
TensorCoordinateIterator
struct
TensorCoordinateIterator
{
// TODO make these private
using
VisibleIndex
=
MultiIndex
<
NDimVisible
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorCoordinateIterator
()
=
default
;
__host__
__device__
constexpr
TensorCoordinateIterator
()
=
default
;
__host__
__device__
constexpr
DynamicTensorCoordinateIterator
(
const
VisibleIndex
&
idx_diff_visible
,
const
MultiIndex
<
NTransform
>&
do_transforms
)
__host__
__device__
constexpr
TensorCoordinateIterator
(
const
VisibleIndex
&
idx_diff_visible
,
const
MultiIndex
<
NTransform
>&
do_transforms
)
:
idx_diff_visible_
{
idx_diff_visible
},
do_transforms_
{
do_transforms
}
{
}
...
...
@@ -283,7 +284,7 @@ struct DynamicTensorCoordinateIterator
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_
dynamic_
tensor_descriptor) because template cannot be defined inside a function
// (transform_tensor_descriptor) because template cannot be defined inside a function
// template
template
<
typename
NewTransforms
>
struct
lambda_get_up_dim_num
...
...
@@ -301,10 +302,10 @@ template <typename OldTensorDescriptor,
typename
NewLowerDimensionOldVisibleIdss
,
typename
NewUpperDimensionNewVisibleIdss
>
__host__
__device__
constexpr
auto
transform_
dynamic_
tensor_descriptor
(
const
OldTensorDescriptor
&
old_tensor_desc
,
const
NewTransforms
&
new_transforms
,
NewLowerDimensionOldVisibleIdss
,
NewUpperDimensionNewVisibleIdss
)
transform_tensor_descriptor
(
const
OldTensorDescriptor
&
old_tensor_desc
,
const
NewTransforms
&
new_transforms
,
NewLowerDimensionOldVisibleIdss
,
NewUpperDimensionNewVisibleIdss
)
{
// sanity check
{
...
...
@@ -376,17 +377,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const
auto
element_space_size
=
old_tensor_desc
.
GetElementSpaceSize
();
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
all_transforms
)
>
,
remove_cv_t
<
decltype
(
all_low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
all_up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
new_visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
all_transforms
)
>
,
remove_cv_t
<
decltype
(
all_low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
all_up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
new_visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
};
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
const
VisibleIndex
&
idx_visible
)
__host__
__device__
constexpr
auto
make_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
const
VisibleIndex
&
idx_visible
)
{
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
"wrong! # of dimension inconsistent"
);
...
...
@@ -416,13 +417,13 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
set_container_subset
(
idx_hidden
,
dims_low
,
idx_low
);
});
return
Dynamic
TensorCoordinate
<
ndim_hidden
,
decltype
(
visible_dim_ids
)
>
{
idx_hidden
};
return
TensorCoordinate
<
ndim_hidden
,
decltype
(
visible_dim_ids
)
>
{
idx_hidden
};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template
<
typename
TensorDesc
,
typename
VisibleIndex
,
typename
UpdateLowerIndexHack
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate_iterator
(
__host__
__device__
constexpr
auto
make_tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
,
UpdateLowerIndexHack
)
{
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
...
...
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
set_container_subset
(
is_non_zero_diff
,
dims_low
,
non_zero_diff_pick_low
);
});
return
Dynamic
TensorCoordinateIterator
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
return
TensorCoordinateIterator
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
idx_diff_visible
,
do_transforms
};
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
)
make_tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
)
{
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
return
make_
dynamic_
tensor_coordinate_iterator
(
return
make_tensor_coordinate_iterator
(
TensorDesc
{},
idx_diff_visible
,
typename
uniform_sequence_gen
<
ntransform
,
0
>::
type
{});
}
template
<
typename
TensorDesc
,
typename
TensorCoord
,
typename
TensorCoordIterator
>
__host__
__device__
constexpr
void
move_dynamic_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
TensorCoord
&
coord
,
const
TensorCoordIterator
&
coord_iterator
)
__host__
__device__
constexpr
void
move_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
TensorCoord
&
coord
,
const
TensorCoordIterator
&
coord_iterator
)
{
constexpr
index_t
ndim_hidden
=
TensorDesc
::
GetNumOfHiddenDimension
();
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
...
...
@@ -524,7 +526,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
MultiIndex
<
dims_low
.
Size
()
>
idx_diff_low
;
// HACK: control UpdateLowerIndex for
Dynamic
Merge using hack
// HACK: control UpdateLowerIndex for Merge using hack
constexpr
index_t
Hack
=
decltype
(
coord_iterator
.
update_lower_index_hack_
)
::
At
(
itran
);
tran
.
UpdateLowerIndex
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up_new
,
Number
<
Hack
>
{});
...
...
@@ -585,11 +587,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
}
template
<
typename
TensorDesc
>
using
Dynamic
TensorCoordinate_t
=
decltype
(
make_
dynamic_
tensor_coordinate
(
using
TensorCoordinate_t
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
using
Dynamic
TensorCoordinateIterator_t
=
decltype
(
make_
dynamic_
tensor_coordinate_iterator
(
using
TensorCoordinateIterator_t
=
decltype
(
make_tensor_coordinate_iterator
(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
...
...
composable_kernel/include/tensor_description/
dynamic_
tensor_descriptor_helper.hpp
→
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HELPER_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "multi_index_transform_helper.hpp"
namespace
ck
{
...
...
@@ -38,9 +38,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template
<
typename
...
Lengths
,
typename
...
Strides
,
typename
std
::
enable_if
<
sizeof
...(
Lengths
)
==
sizeof
...(
Strides
),
bool
>
::
type
=
false
>
__host__
__device__
constexpr
auto
make_dynamic_naive_tensor_descriptor_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
const
Tuple
<
Strides
...
>&
strides
)
__host__
__device__
constexpr
auto
make_naive_tensor_descriptor_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
const
Tuple
<
Strides
...
>&
strides
)
{
constexpr
index_t
N
=
sizeof
...(
Lengths
);
...
...
@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
calculate_element_space_size_impl
(
lengths
,
strides
,
Number
<
0
>
{},
Number
<
1
>
{});
#endif
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
}
// Lengths... can be:
...
...
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time
template
<
typename
...
Lengths
>
__host__
__device__
constexpr
auto
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
const
Tuple
<
Lengths
...
>&
lengths
)
make_naive_tensor_descriptor_packed
(
const
Tuple
<
Lengths
...
>&
lengths
)
{
constexpr
index_t
N
=
sizeof
...(
Lengths
);
...
...
@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
const
auto
element_space_size
=
container_reduce
(
lengths
,
math
::
multiplies_v2
{},
Number
<
1
>
{});
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
}
template
<
typename
...
Lengths
,
typename
Align
>
__host__
__device__
constexpr
auto
make_
dynamic_
naive_tensor_descriptor_aligned_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
Align
align
)
make_naive_tensor_descriptor_aligned_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
Align
align
)
{
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
},
Number
<
N
>
{});
return
make_
dynamic_
naive_tensor_descriptor_v2
(
lengths
,
strides
);
return
make_naive_tensor_descriptor_v2
(
lengths
,
strides
);
}
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
View file @
c03045ce
...
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
...
...
@@ -73,7 +73,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__
__device__
static
constexpr
auto
MakeAKM0M1BlockDescriptor
(
const
AKMBlockDesc
&
/* a_k_m_block_desc */
)
{
const
auto
a_k_m0_m1_block_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
a_k_m0_m1_block_desc
=
transform_tensor_descriptor
(
AKMBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M1
>
{}))),
...
...
@@ -86,7 +86,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__
__device__
static
constexpr
auto
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
/* b_k_n_block_desc */
)
{
const
auto
b_k_n0_n1_block_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
b_k_n0_n1_block_desc
=
transform_tensor_descriptor
(
BKNBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N1
>
{}))),
...
...
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
private:
// A[K, M0, M1]
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0
>
{},
Number
<
M1PerThreadM11
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0
>
{},
Number
<
N1PerThreadN11
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
View file @
c03045ce
...
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
...
...
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_
dynamic_
tensor_descriptor
(
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
...
...
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_
dynamic_
tensor_descriptor
(
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
...
...
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThreadBM11
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThreadBN11
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4r1
<
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
...
...
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
Sequence
<
1
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4r1
<
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
c03045ce
...
...
@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// HACK: fix this @Jing Zhang
static
constexpr
index_t
KPerThreadSubC
=
4
;
static
constexpr
auto
a_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadSubC
>
{}));
static
constexpr
auto
b_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
static
constexpr
auto
b_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
static
constexpr
auto
c_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadSubC
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_begin_mtx_idx_
{
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
())},
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
c03045ce
...
...
@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
namespace
ck
{
...
...
@@ -191,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
MRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
NRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
MRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
NRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
...
...
@@ -486,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
...
...
composable_kernel/include/tensor_operation/blockwise_
dynamic_
tensor_slice_transfer.hpp
→
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
View file @
c03045ce
#ifndef CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_HPP
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. Threadwise
Dynamic
TensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. Threadwise
Dynamic
TensorSliceTransfer_v3::Run() does not construct new tensor coordinate
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
index_t
BlockSize
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
...
...
@@ -33,16 +33,16 @@ template <index_t BlockSize,
index_t
DstScalarStrideInVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
Blockwise
Dynamic
TensorSliceTransfer_v4
struct
BlockwiseTensorSliceTransfer_v4
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
Blockwise
Dynamic
TensorSliceTransfer_v4
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
__device__
constexpr
BlockwiseTensorSliceTransfer_v4
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
())
...
...
@@ -147,22 +147,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
Threadwise
Dynamic
TensorSliceTransfer_v3
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTensorSliceTransfer_v3
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
composable_kernel/include/tensor_operation/blockwise_
dynamic_
tensor_slice_transfer_v2.hpp
→
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
View file @
c03045ce
#ifndef CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. Threadwise
Dynamic
TensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. Threadwise
Dynamic
TensorSliceTransfer_v3::Run() does not construct new tensor coordinate
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
index_t
BlockSize
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
...
...
@@ -31,17 +31,16 @@ template <index_t BlockSize,
typename
DstVectorTensorContiguousDimOrder
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
Blockwise
Dynamic
TensorSliceTransfer_v4r1
struct
BlockwiseTensorSliceTransfer_v4r1
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseDynamicTensorSliceTransfer_v4r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
__device__
constexpr
BlockwiseTensorSliceTransfer_v4r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
())
...
...
@@ -136,20 +135,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
Threadwise
Dynamic
TensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment