Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c03045ce
"convert/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "107f6959299d0cc18ef15df23cee5eaae8ffbf4e"
Commit
c03045ce
authored
Aug 10, 2021
by
Chao Liu
Browse files
rename
parent
b2589957
Changes
54
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
586 additions
and
586 deletions
+586
-586
README.md
README.md
+5
-5
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
...ckward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
+36
-36
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
...ward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
+36
-36
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
...orm_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
+37
-37
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...orm_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+25
-25
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...m_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
...m_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...m_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+29
-29
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
...orm_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
+23
-23
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+98
-94
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
...clude/tensor_description/multi_index_transform_helper.hpp
+20
-21
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+2
-2
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+44
-42
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
...l/include/tensor_description/tensor_descriptor_helper.hpp
+21
-22
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
...el/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
+24
-26
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
...el/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
+7
-7
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+12
-13
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+55
-55
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
...lude/tensor_operation/blockwise_tensor_slice_transfer.hpp
+28
-28
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
...e/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
+26
-27
No files found.
README.md
View file @
c03045ce
...
...
@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888}
...
...
@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176}
...
...
@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256}
...
...
@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
...
...
@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
...
...
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -23,9 +23,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const
auto
K0
=
K
/
K1
;
// weight tensor
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilda
),
...
...
@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
#if 1
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
...
...
@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_pass_through_transform
(
C
),
...
...
@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
// output tensor
// this add padding check
const
auto
out_n_hop_wop_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
...
...
@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilda
),
...
...
@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
...
...
@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence
<
5
,
6
>
{}));
#if 1
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#endif
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilda
,
HTilda
),
...
...
@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
IYTilda
),
...
...
@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
))),
...
...
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -26,9 +26,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
// A: output tensor
// this add padding check
const
auto
out_n_hop_wop_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
I0
),
...
...
@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilda
),
...
...
@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
...
...
@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence
<
5
,
6
>
{}));
#if 1
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
...
...
@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif
// B: weight tensor
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilda
),
...
...
@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTilda
),
make_freeze_transform
(
IXTilda
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
#if 1
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_pass_through_transform
(
C
),
...
...
@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_pass_through_transform
(
C
),
...
...
@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif
// C: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTilda
,
HTilda
),
...
...
@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_htildaslice_wtildaslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
IYTilda
),
...
...
@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildaSlice
,
WTildaSlice
)),
make_pass_through_transform
(
C
)),
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -18,9 +18,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
@@ -109,9 +109,9 @@ template <typename... Wei,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -147,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
assert
(
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -164,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
@@ -189,9 +189,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -229,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
)),
const
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -18,9 +18,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
@@ -108,9 +108,9 @@ template <typename... Wei,
typename
InLeftPads
,
typename
InRightPads
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -148,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
InRightPadW
==
0
);
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
C
)),
const
auto
in_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
C
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -20,9 +20,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -20,9 +20,9 @@ template <typename... Wei,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// weight tensor
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -23,9 +23,9 @@ template <typename... In,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// A: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
...
...
@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
...
...
@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmm_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: weight tensor
const
auto
wei_gemmk_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
Y
*
X
*
C
)),
const
auto
wei_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: output tensor
const
auto
out_gemmm_gemmn_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
const
auto
out_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
@@ -24,9 +24,9 @@ template <typename... Wei,
typename
C0Type
>
__host__
__device__
constexpr
auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
const
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
...
...
@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const
auto
C1
=
C
/
C0
;
// weight tensor
const
auto
wei_gk0_gm0_gm1_gk1_grid_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic
_naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
*
Y
*
X
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
0
>
{}));
const
auto
wei_gk0_gm0_gm1_gk1_grid_desc
=
transform_tensor_descriptor
(
make
_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
*
Y
*
X
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
0
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
...
...
@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
C0
,
C1
)),
...
...
@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
,
7
>
{}));
const
auto
in_gk0_gn0_gn1_gk1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_gk0_gn0_gn1_gk1_grid_desc
=
transform_tensor_descriptor
(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C1
,
Y
,
X
)),
make_pass_through_transform
(
N0
),
...
...
@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
// output tensor
const
auto
out_n_k_howo_grid_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
));
const
auto
out_n0_n1_1_k_howo_grid_desc
=
transform_dynamic_tensor_descriptor
(
out_n_k_howo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_pass_through_transform
(
Ho
*
Wo
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_n0_n1_1_k_howo_grid_desc
=
transform_tensor_descriptor
(
out_n_k_howo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_unmerge_transform
(
make_tuple
(
I1
,
K
)),
make_pass_through_transform
(
Ho
*
Wo
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_gm0_gm1_gn0_gn1_grid_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_gm0_gm1_gn0_gn1_grid_desc
=
transform_tensor_descriptor
(
out_n0_n1_1_k_howo_grid_desc
,
make_tuple
(
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
K
),
...
...
composable_kernel/include/tensor_description/
dynamic_
multi_index_transform.hpp
→
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
c03045ce
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_description/
dynamic_
multi_index_transform_helper.hpp
→
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_
DYNAMIC_
MULTI_INDEX_TRANSFORM_HELPER_HPP
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp"
#include "
dynamic_
multi_index_transform.hpp"
#include "multi_index_transform.hpp"
namespace
ck
{
template
<
typename
LowLength
>
__host__
__device__
constexpr
auto
make_pass_through_transform
(
const
LowLength
&
low_length
)
{
return
Dynamic
PassThrough
<
LowLength
>
{
low_length
};
return
PassThrough
<
LowLength
>
{
low_length
};
}
template
<
typename
LowLength
,
typename
LeftPad
,
typename
RightPad
,
bool
SkipIsValidCheck
=
false
>
...
...
@@ -19,26 +19,25 @@ make_pad_transform(const LowLength& low_length,
const
RightPad
&
right_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
DynamicPad
<
LowLength
,
LeftPad
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
,
right_pad
};
return
Pad
<
LowLength
,
LeftPad
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
,
right_pad
};
}
template
<
typename
LowLength
,
typename
LeftPad
,
bool
SkipIsValidCheck
=
false
>
template
<
typename
LowLength
,
typename
LeftPad
Length
,
bool
SkipIsValidCheck
=
false
>
__host__
__device__
constexpr
auto
make_left_pad_transform
(
const
LowLength
&
low_length
,
const
LeftPad
&
left_pad
,
const
LeftPad
Length
&
left_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
LeftPad
<
LowLength
,
LeftPad
,
SkipIsValidCheck
>
{
low_length
,
left_pad
};
return
LeftPad
<
LowLength
,
LeftPad
Length
,
SkipIsValidCheck
>
{
low_length
,
left_pad
};
}
template
<
typename
LowLength
,
typename
RightPad
,
bool
SkipIsValidCheck
>
template
<
typename
LowLength
,
typename
RightPad
Length
,
bool
SkipIsValidCheck
>
__host__
__device__
constexpr
auto
make_right_pad_transform
(
const
LowLength
&
low_length
,
const
RightPad
&
right_pad
,
const
RightPad
Length
&
right_pad
,
integral_constant
<
bool
,
SkipIsValidCheck
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
RightPad
<
LowLength
,
RightPad
,
SkipIsValidCheck
>
{
low_length
,
right_pad
};
return
RightPad
<
LowLength
,
RightPad
Length
,
SkipIsValidCheck
>
{
low_length
,
right_pad
};
}
template
<
typename
UpLengths
,
...
...
@@ -47,19 +46,19 @@ template <typename UpLengths,
__host__
__device__
constexpr
auto
make_embed_transform
(
const
UpLengths
&
up_lengths
,
const
Coefficients
&
coefficients
)
{
return
Dynamic
Embed
<
UpLengths
,
Coefficients
>
{
up_lengths
,
coefficients
};
return
Embed
<
UpLengths
,
Coefficients
>
{
up_lengths
,
coefficients
};
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform
(
const
LowLengths
&
low_lengths
)
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return
Dynamic
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
#else
#if 1
return
Dynamic
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
#else
return
Dynamic
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
#endif
#endif
}
...
...
@@ -68,7 +67,7 @@ template <typename LowLengths>
__host__
__device__
constexpr
auto
make_merge_transform_v2_magic_division
(
const
LowLengths
&
low_lengths
)
{
return
Dynamic
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
=
false
>
...
...
@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
const
UpLengths
&
up_lengths
,
integral_constant
<
bool
,
Use24BitIntegerCalculation
>
=
integral_constant
<
bool
,
false
>
{})
{
return
Dynamic
UnMerge
<
UpLengths
,
Use24BitIntegerCalculation
>
{
up_lengths
};
return
UnMerge
<
UpLengths
,
Use24BitIntegerCalculation
>
{
up_lengths
};
}
template
<
typename
LowerIndex
>
__host__
__device__
constexpr
auto
make_freeze_transform
(
const
LowerIndex
&
low_idx
)
{
return
Dynamic
Freeze
<
LowerIndex
>
{
low_idx
};
return
Freeze
<
LowerIndex
>
{
low_idx
};
}
template
<
typename
LowLength
,
typename
SliceBegin
,
typename
SliceEnd
>
...
...
@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
const
SliceBegin
&
slice_begin
,
const
SliceEnd
&
slice_end
)
{
return
Dynamic
Slice
<
LowLength
,
SliceBegin
,
SliceEnd
>
{
low_length
,
slice_begin
,
slice_end
};
return
Slice
<
LowLength
,
SliceBegin
,
SliceEnd
>
{
low_length
,
slice_begin
,
slice_end
};
}
template
<
typename
VectorSize
,
typename
UpLength
>
__host__
__device__
constexpr
auto
make_vectorize_transform
(
const
VectorSize
&
vector_size
,
const
UpLength
&
up_length
)
{
return
Dynamic
Vectorize
<
VectorSize
,
UpLength
>
{
vector_size
,
up_length
};
return
Vectorize
<
VectorSize
,
UpLength
>
{
vector_size
,
up_length
};
}
}
// namespace ck
...
...
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
c03045ce
...
...
@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
...
...
composable_kernel/include/tensor_description/
dynamic_
tensor_descriptor.hpp
→
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HPP
#define CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "
dynamic_
multi_index_transform.hpp"
#include "multi_index_transform.hpp"
namespace
ck
{
template
<
index_t
NDimHidden
,
typename
VisibleDimensionIds
>
struct
Dynamic
TensorCoordinate
;
struct
TensorCoordinate
;
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
Dynamic
TensorCoordinateIterator
;
struct
TensorCoordinateIterator
;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
...
...
@@ -21,7 +21,7 @@ template <typename Transforms,
typename
UpperDimensionIdss
,
typename
VisibleDimensionIds
,
typename
ElementSpaceSize
>
struct
Dynamic
TensorDescriptor
struct
TensorDescriptor
{
// TODO make these private
__host__
__device__
static
constexpr
index_t
GetNumOfTransform
()
{
return
Transforms
::
Size
();
}
...
...
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
using
VisibleIndex
=
MultiIndex
<
ndim_visible_
>
;
using
HiddenIndex
=
MultiIndex
<
ndim_hidden_
>
;
using
Coordinate
=
Dynamic
TensorCoordinate
<
ndim_hidden_
,
VisibleDimensionIds
>
;
using
Coordinate
=
TensorCoordinate
<
ndim_hidden_
,
VisibleDimensionIds
>
;
// may be index_t or Number<>
using
ElementSize
=
remove_cv_t
<
decltype
(
InitializeElementSize
(
Transforms
{}))
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorDescriptor
()
=
default
;
__host__
__device__
constexpr
TensorDescriptor
()
=
default
;
__host__
__device__
constexpr
Dynamic
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
__host__
__device__
constexpr
TensorDescriptor
(
const
Transforms
&
transforms
,
ElementSpaceSize
element_space_size
)
:
transforms_
{
transforms
},
element_size_
{
InitializeElementSize
(
transforms
)},
element_space_size_
{
element_space_size
}
...
...
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
{
static_assert
(
Idx
::
Size
()
==
GetNumOfDimension
(),
"wrong! inconsistent # of dimension"
);
return
make_
dynamic_
tensor_coordinate
(
*
this
,
idx
).
GetOffset
();
return
make_tensor_coordinate
(
*
this
,
idx
).
GetOffset
();
}
// TODO make these private
...
...
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"
Dynamic
TensorDescriptor, "
);
printf
(
"TensorDescriptor, "
);
static_for
<
0
,
ntransform_
,
1
>
{}([
&
](
auto
i
)
{
printf
(
"transforms: "
);
transforms_
[
i
].
Print
();
...
...
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
};
template
<
index_t
NDimHidden
,
typename
VisibleDimensionIds
>
struct
Dynamic
TensorCoordinate
struct
TensorCoordinate
{
// TODO make these private
static
constexpr
index_t
ndim_visible_
=
VisibleDimensionIds
::
Size
();
...
...
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
using
VisibleIndex
=
MultiIndex
<
ndim_visible_
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorCoordinate
()
=
default
;
__host__
__device__
constexpr
TensorCoordinate
()
=
default
;
__host__
__device__
constexpr
Dynamic
TensorCoordinate
(
const
HiddenIndex
&
idx_hidden
)
__host__
__device__
constexpr
TensorCoordinate
(
const
HiddenIndex
&
idx_hidden
)
:
idx_hidden_
{
idx_hidden
}
{
}
...
...
@@ -252,16 +252,17 @@ struct DynamicTensorCoordinate
};
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
Dynamic
TensorCoordinateIterator
struct
TensorCoordinateIterator
{
// TODO make these private
using
VisibleIndex
=
MultiIndex
<
NDimVisible
>
;
public:
__host__
__device__
constexpr
Dynamic
TensorCoordinateIterator
()
=
default
;
__host__
__device__
constexpr
TensorCoordinateIterator
()
=
default
;
__host__
__device__
constexpr
DynamicTensorCoordinateIterator
(
const
VisibleIndex
&
idx_diff_visible
,
const
MultiIndex
<
NTransform
>&
do_transforms
)
__host__
__device__
constexpr
TensorCoordinateIterator
(
const
VisibleIndex
&
idx_diff_visible
,
const
MultiIndex
<
NTransform
>&
do_transforms
)
:
idx_diff_visible_
{
idx_diff_visible
},
do_transforms_
{
do_transforms
}
{
}
...
...
@@ -283,7 +284,7 @@ struct DynamicTensorCoordinateIterator
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_
dynamic_
tensor_descriptor) because template cannot be defined inside a function
// (transform_tensor_descriptor) because template cannot be defined inside a function
// template
template
<
typename
NewTransforms
>
struct
lambda_get_up_dim_num
...
...
@@ -301,10 +302,10 @@ template <typename OldTensorDescriptor,
typename
NewLowerDimensionOldVisibleIdss
,
typename
NewUpperDimensionNewVisibleIdss
>
__host__
__device__
constexpr
auto
transform_
dynamic_
tensor_descriptor
(
const
OldTensorDescriptor
&
old_tensor_desc
,
const
NewTransforms
&
new_transforms
,
NewLowerDimensionOldVisibleIdss
,
NewUpperDimensionNewVisibleIdss
)
transform_tensor_descriptor
(
const
OldTensorDescriptor
&
old_tensor_desc
,
const
NewTransforms
&
new_transforms
,
NewLowerDimensionOldVisibleIdss
,
NewUpperDimensionNewVisibleIdss
)
{
// sanity check
{
...
...
@@ -376,17 +377,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const
auto
element_space_size
=
old_tensor_desc
.
GetElementSpaceSize
();
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
all_transforms
)
>
,
remove_cv_t
<
decltype
(
all_low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
all_up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
new_visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
all_transforms
)
>
,
remove_cv_t
<
decltype
(
all_low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
all_up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
new_visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
};
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
const
VisibleIndex
&
idx_visible
)
__host__
__device__
constexpr
auto
make_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
const
VisibleIndex
&
idx_visible
)
{
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
"wrong! # of dimension inconsistent"
);
...
...
@@ -416,13 +417,13 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
set_container_subset
(
idx_hidden
,
dims_low
,
idx_low
);
});
return
Dynamic
TensorCoordinate
<
ndim_hidden
,
decltype
(
visible_dim_ids
)
>
{
idx_hidden
};
return
TensorCoordinate
<
ndim_hidden
,
decltype
(
visible_dim_ids
)
>
{
idx_hidden
};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template
<
typename
TensorDesc
,
typename
VisibleIndex
,
typename
UpdateLowerIndexHack
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate_iterator
(
__host__
__device__
constexpr
auto
make_tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
,
UpdateLowerIndexHack
)
{
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
...
...
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
set_container_subset
(
is_non_zero_diff
,
dims_low
,
non_zero_diff_pick_low
);
});
return
Dynamic
TensorCoordinateIterator
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
return
TensorCoordinateIterator
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
idx_diff_visible
,
do_transforms
};
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
__host__
__device__
constexpr
auto
make_
dynamic_
tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
)
make_tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
)
{
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
return
make_
dynamic_
tensor_coordinate_iterator
(
return
make_tensor_coordinate_iterator
(
TensorDesc
{},
idx_diff_visible
,
typename
uniform_sequence_gen
<
ntransform
,
0
>::
type
{});
}
template
<
typename
TensorDesc
,
typename
TensorCoord
,
typename
TensorCoordIterator
>
__host__
__device__
constexpr
void
move_dynamic_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
TensorCoord
&
coord
,
const
TensorCoordIterator
&
coord_iterator
)
__host__
__device__
constexpr
void
move_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
TensorCoord
&
coord
,
const
TensorCoordIterator
&
coord_iterator
)
{
constexpr
index_t
ndim_hidden
=
TensorDesc
::
GetNumOfHiddenDimension
();
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
...
...
@@ -524,7 +526,7 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
MultiIndex
<
dims_low
.
Size
()
>
idx_diff_low
;
// HACK: control UpdateLowerIndex for
Dynamic
Merge using hack
// HACK: control UpdateLowerIndex for Merge using hack
constexpr
index_t
Hack
=
decltype
(
coord_iterator
.
update_lower_index_hack_
)
::
At
(
itran
);
tran
.
UpdateLowerIndex
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up_new
,
Number
<
Hack
>
{});
...
...
@@ -585,11 +587,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
}
template
<
typename
TensorDesc
>
using
Dynamic
TensorCoordinate_t
=
decltype
(
make_
dynamic_
tensor_coordinate
(
using
TensorCoordinate_t
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
using
Dynamic
TensorCoordinateIterator_t
=
decltype
(
make_
dynamic_
tensor_coordinate_iterator
(
using
TensorCoordinateIterator_t
=
decltype
(
make_tensor_coordinate_iterator
(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
...
...
composable_kernel/include/tensor_description/
dynamic_
tensor_descriptor_helper.hpp
→
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
View file @
c03045ce
#ifndef CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_
DYNAMIC_
TENSOR_DESCRIPTOR_HELPER_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "multi_index_transform_helper.hpp"
namespace
ck
{
...
...
@@ -38,9 +38,8 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template
<
typename
...
Lengths
,
typename
...
Strides
,
typename
std
::
enable_if
<
sizeof
...(
Lengths
)
==
sizeof
...(
Strides
),
bool
>
::
type
=
false
>
__host__
__device__
constexpr
auto
make_dynamic_naive_tensor_descriptor_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
const
Tuple
<
Strides
...
>&
strides
)
__host__
__device__
constexpr
auto
make_naive_tensor_descriptor_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
const
Tuple
<
Strides
...
>&
strides
)
{
constexpr
index_t
N
=
sizeof
...(
Lengths
);
...
...
@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
calculate_element_space_size_impl
(
lengths
,
strides
,
Number
<
0
>
{},
Number
<
1
>
{});
#endif
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
}
// Lengths... can be:
...
...
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time
template
<
typename
...
Lengths
>
__host__
__device__
constexpr
auto
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
const
Tuple
<
Lengths
...
>&
lengths
)
make_naive_tensor_descriptor_packed
(
const
Tuple
<
Lengths
...
>&
lengths
)
{
constexpr
index_t
N
=
sizeof
...(
Lengths
);
...
...
@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
const
auto
element_space_size
=
container_reduce
(
lengths
,
math
::
multiplies_v2
{},
Number
<
1
>
{});
return
Dynamic
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
return
TensorDescriptor
<
remove_cv_t
<
decltype
(
transforms
)
>
,
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
};
}
template
<
typename
...
Lengths
,
typename
Align
>
__host__
__device__
constexpr
auto
make_
dynamic_
naive_tensor_descriptor_aligned_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
Align
align
)
make_naive_tensor_descriptor_aligned_v2
(
const
Tuple
<
Lengths
...
>&
lengths
,
Align
align
)
{
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
},
Number
<
N
>
{});
return
make_
dynamic_
naive_tensor_descriptor_v2
(
lengths
,
strides
);
return
make_naive_tensor_descriptor_v2
(
lengths
,
strides
);
}
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp
View file @
c03045ce
...
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
...
...
@@ -73,7 +73,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__
__device__
static
constexpr
auto
MakeAKM0M1BlockDescriptor
(
const
AKMBlockDesc
&
/* a_k_m_block_desc */
)
{
const
auto
a_k_m0_m1_block_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
a_k_m0_m1_block_desc
=
transform_tensor_descriptor
(
AKMBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M1
>
{}))),
...
...
@@ -86,7 +86,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
__host__
__device__
static
constexpr
auto
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
/* b_k_n_block_desc */
)
{
const
auto
b_k_n0_n1_block_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
b_k_n0_n1_block_desc
=
transform_tensor_descriptor
(
BKNBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N1
>
{}))),
...
...
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
private:
// A[K, M0, M1]
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0
>
{},
Number
<
M1PerThreadM11
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0
>
{},
Number
<
N1PerThreadN11
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp
View file @
c03045ce
...
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
...
...
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_
dynamic_
tensor_descriptor
(
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
...
...
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_
dynamic_
tensor_descriptor
(
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
...
...
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThreadBM11
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThreadBN11
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4r1
<
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
...
...
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
Sequence
<
1
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4r1
<
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
c03045ce
...
...
@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// HACK: fix this @Jing Zhang
static
constexpr
index_t
KPerThreadSubC
=
4
;
static
constexpr
auto
a_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadSubC
>
{}));
static
constexpr
auto
b_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
static
constexpr
auto
b_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
static
constexpr
auto
c_thread_mtx_
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadSubC
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_begin_mtx_idx_
{
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
())},
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
c03045ce
...
...
@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
namespace
ck
{
...
...
@@ -191,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
MRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
NRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
MRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
NRepeat
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
K1
,
1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
...
...
@@ -486,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private:
// A[K, M]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
I1
,
Number
<
K1
>
{}));
// B[K, N]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
using
BThreadCopy
=
Threadwise
Dynamic
TensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
NRepeat
>
{},
I1
,
Number
<
K1
>
{}));
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
K1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
1
,
// K1,
1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
...
...
composable_kernel/include/tensor_operation/blockwise_
dynamic_
tensor_slice_transfer.hpp
→
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
View file @
c03045ce
#ifndef CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_HPP
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. Threadwise
Dynamic
TensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. Threadwise
Dynamic
TensorSliceTransfer_v3::Run() does not construct new tensor coordinate
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
index_t
BlockSize
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
...
...
@@ -33,16 +33,16 @@ template <index_t BlockSize,
index_t
DstScalarStrideInVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
Blockwise
Dynamic
TensorSliceTransfer_v4
struct
BlockwiseTensorSliceTransfer_v4
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
Blockwise
Dynamic
TensorSliceTransfer_v4
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
__device__
constexpr
BlockwiseTensorSliceTransfer_v4
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
())
...
...
@@ -147,22 +147,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
Threadwise
Dynamic
TensorSliceTransfer_v3
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTensorSliceTransfer_v3
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
composable_kernel/include/tensor_operation/blockwise_
dynamic_
tensor_slice_transfer_v2.hpp
→
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
View file @
c03045ce
#ifndef CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_
DYNAMIC_
TENSOR_SLICE_TRANSFER_V2_HPP
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_
dynamic_
tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. Threadwise
Dynamic
TensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. Threadwise
Dynamic
TensorSliceTransfer_v3::Run() does not construct new tensor coordinate
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
index_t
BlockSize
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
...
...
@@ -31,17 +31,16 @@ template <index_t BlockSize,
typename
DstVectorTensorContiguousDimOrder
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
Blockwise
Dynamic
TensorSliceTransfer_v4r1
struct
BlockwiseTensorSliceTransfer_v4r1
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseDynamicTensorSliceTransfer_v4r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
__device__
constexpr
BlockwiseTensorSliceTransfer_v4r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
())
...
...
@@ -136,20 +135,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
make_cluster_descriptor_v2
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
Threadwise
Dynamic
TensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTensorSliceTransfer_v3r1
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment