Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d8c89b68
"vscode:/vscode.git/clone" did not exist on "b3d5e87b4f27ea8f41acf6b8186a4a3fe3d6d68e"
Commit
d8c89b68
authored
May 10, 2021
by
Chao Liu
Browse files
refactor driver for conv
parent
fd160c63
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1360 additions
and
2643 deletions
+1360
-2643
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+419
-1525
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+281
-1014
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
+396
-0
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+21
-0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+37
-8
driver/include/device.hpp
driver/include/device.hpp
+22
-4
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+94
-44
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+87
-45
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+3
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
d8c89b68
...
@@ -4,57 +4,33 @@
...
@@ -4,57 +4,33 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace
ck
{
namespace
ck
{
// GemmM = K
// GemmM = K
// GemmN = N * Ho * Wo
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
// GemmK = C * Y * X
template
<
index_t
BlockSize
,
template
<
index_t
GemmMPerBlock
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmM1
,
index_t
GemmMPerThread
,
index_t
GemmN1
,
index_t
GemmNPerThread
,
typename
...
Wei
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
)
const
FloatAB
*
__restrict__
p_wei_global
,
{
const
FloatAB
*
__restrict__
p_in_global
,
FloatC
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -104,16 +80,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -104,16 +80,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_gemmk_gemmn_global_desc
=
in_n_c_y_ho_x_wo_global_desc
,
transform_dynamic_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
...
@@ -122,8 +97,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -122,8 +97,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
// output tensor
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -131,50 +105,42 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -131,50 +105,42 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
auto
GemmM1
=
Number
<
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
>
{};
constexpr
auto
GemmN1
=
Number
<
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
>
{};
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{}
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{}
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
//
c
_block_cluster_desc
//
out_gemm
_block_cluster_desc
const
auto
gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
const
auto
out_
gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over
a_k_
m_global tensor
// hack to control index calculation when iterating over
wei_gemmk_gemm
m_global tensor
constexpr
auto
a_k_
m_global_iterator_hacks
=
constexpr
auto
wei_gemmk_gemm
m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_k_
m_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk_gemm
m_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over
b_k_
n_global tensor
// hack to control index calculation when iterating over
in_gemmk_gemm
n_global tensor
constexpr
auto
b_k_
n_global_iterator_hacks
=
constexpr
auto
in_gemmk_gemm
n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
b_k_
n_global_move_slice_window_iterator_hack
=
constexpr
auto
in_gemmk_gemm
n_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over
c_m0_m1_n0_
n1_global
tensor
// hack to control index calculation when iterating over
out_gemmm0_gemmm1_gemmn0_gemm
n1_global
// hack for NKHW format
//
tensor
hack for NKHW format
constexpr
auto
c_m0_m1_n0_
n1_global_
tensor_
iterator_hacks
=
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemm
n1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
...
@@ -184,407 +150,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -184,407 +150,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
// GEMM
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
decltype
(
gemm_block_cluster_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
a_k_m_global_iterator_hacks
),
decltype
(
b_k_n_global_iterator_hacks
),
decltype
(
c_m0_m1_n0_n1_global_tensor_iterator_hacks
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
out_gemm_block_cluster_desc
,
gemm_block_cluster_desc
,
wei_gemmk_gemmm_global_iterator_hacks
,
integral_constant
<
bool
,
false
>
{},
in_gemmk_gemmn_global_iterator_hacks
,
integral_constant
<
bool
,
false
>
{});
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
}
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
}
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
wei_k_c_y_x_global_desc
,
out_n_k_ho_wo_global_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
true
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
true
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
wei_k_c_y_x_global_desc
,
out_n_k_ho_wo_global_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
};
#if 0
// GemmM = K
// GemmM = K
// GemmN = N * Ho * Wo
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
// GemmK = C * Y * X
template <index_t BlockSize,
template
<
index_t
GemmMPerBlock
,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t GemmKPerBlock,
index_t
GemmM1
,
index_t GemmMPerThread,
index_t
GemmN1
,
index_t GemmNPerThread,
typename
...
Wei
,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{
template <typename... Wei,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const InRightPads& in_right_pads,
const
InRightPads
&
in_right_pads
)
const FloatAB* __restrict__ p_wei_global,
{
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -615,10 +215,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -615,10 +215,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if(!(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0))
assert
(
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
);
{
throw std::runtime_error("wrong! no padding");
}
// weight tensor
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -630,16 +227,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -630,16 +227,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
// input tensor
// input tensor
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
make_tuple(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform(N),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(
const
auto
in_gemmk_gemmn_global_desc
=
in_n_c_y_ho_x_wo_global_desc,
transform_dynamic_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
...
@@ -648,8 +244,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -648,8 +244,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
// output tensor
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple(make_pass_through_transform(K),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -657,44 +252,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -657,44 +252,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
const auto GemmM0 = GemmM / GemmM1;
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{}
;
const auto GemmN0 = GemmN / GemmN1;
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{}
;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// out_gemm_block_cluster_desc
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto
a_k_
m_global_iterator_hacks =
constexpr
auto
wei_gemmk_gemm
m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr auto
a_k_
m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr
auto
wei_gemmk_gemm
m_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto
b_k_
n_global_iterator_hacks =
make_tuple(
constexpr
auto
in_gemmk_gemm
n_global_iterator_hacks
=
make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr auto
b_k_
n_global_move_slice_window_iterator_hack =
constexpr
auto
in_gemmk_gemm
n_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
// hack for NKHW format
constexpr auto
c_m0_m1_n0_
n1_global_
tensor_
iterator_hacks =
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemm
n1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
...
@@ -704,394 +295,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -704,394 +295,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
// GEMM
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global,
out_gemm_block_cluster_desc
,
integral_constant<bool, true>{},
wei_gemmk_gemmm_global_iterator_hacks
,
integral_constant<bool, true>{});
in_gemmk_gemmn_global_iterator_hacks
,
}
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
{
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
const auto kernel =
}
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
// GemmM = K
const FloatAB*,
// GemmN = N * Ho * Wo
decltype(in_gemmk_gemmn_global_desc),
// GemmK = C * Y * X
const FloatAB*,
template
<
index_t
GemmMPerBlock
,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
}
;
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmM1
,
index_t
GemmMPerThread
,
index_t
GemmN1
,
index_t
GemmNPerThread
,
typename
...
Wei
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
)
const
FloatAB
*
__restrict__
p_wei_global
,
{
const
FloatAB
*
__restrict__
p_in_global
,
FloatC
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -1122,12 +359,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1122,12 +359,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if
(
!
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
assert
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
))
InRightPadW
==
0
);
{
throw
std
::
runtime_error
(
"wrong! 1x1, stride 1, no padding"
);
}
// weight tensor
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -1146,8 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1146,8 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
// output tensor
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -1155,42 +388,38 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1155,42 +388,38 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
const
expr
auto
GemmM
1
=
Number
<
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
>
{};
const
auto
GemmM
0
=
GemmM
/
Number
<
GemmM1
>
{};
const
expr
auto
GemmN
1
=
Number
<
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
>
{};
const
auto
GemmN
0
=
GemmN
/
Number
<
GemmN1
>
{};
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// out_gemm_block_cluster_desc
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_k_
m_global_iterator_hacks
=
constexpr
auto
wei_gemmk_gemm
m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_k_
m_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk_gemm
m_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr
auto
b_k_
n_global_iterator_hacks
=
constexpr
auto
in_gemmk_gemm
n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
1
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
2
>
{}));
constexpr
auto
b_k_
n_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
1
,
2
>
{};
constexpr
auto
in_gemmk_gemm
n_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr
auto
c_m0_m1_n0_
n1_global_
tensor_
iterator_hacks
=
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemm
n1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
...
@@ -1200,351 +429,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
@@ -1200,351 +429,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
// GEMM
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
a_k_m_global_iterator_hacks
),
decltype
(
b_k_n_global_iterator_hacks
),
decltype
(
c_m0_m1_n0_n1_global_tensor_iterator_hacks
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
out_gemm_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
wei_gemmk_gemmm_global_iterator_hacks
,
integral_constant
<
bool
,
true
>
{});
in_gemmk_gemmn_global_iterator_hacks
,
}
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
{
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
const
auto
kernel
=
}
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
wei_k_c_y_x_global_desc
,
out_n_k_ho_wo_global_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
true
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
true
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
wei_k_c_y_x_global_desc
,
out_n_k_ho_wo_global_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
};
#endif
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
d8c89b68
...
@@ -4,57 +4,33 @@
...
@@ -4,57 +4,33 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace
ck
{
namespace
ck
{
// GemmM = K
// GemmM = K
// GemmN = N * Ho * Wo
// GemmN = N * Ho * Wo
// GemmK = Y * X * C
// GemmK = C * Y * X
template
<
index_t
BlockSize
,
template
<
index_t
GemmMPerBlock
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmM1
,
index_t
GemmMPerThread
,
index_t
GemmN1
,
index_t
GemmNPerThread
,
typename
...
Wei
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_global_desc
,
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_hi_wi_c_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_hi_wi_c_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
)
const
FloatAB
*
__restrict__
p_wei_global
,
{
const
FloatAB
*
__restrict__
p_in_global
,
FloatC
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -104,16 +80,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
...
@@ -104,16 +80,15 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const
auto
in_n_y_ho_x_wo_c_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_hip_wip_c_global_desc
,
in_n_hip_wip_c_global_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_gemmk_gemmn_global_desc
=
in_n_y_ho_x_wo_c_global_desc
,
transform_dynamic_tensor_descriptor
(
in_n_y_ho_x_wo_c_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
...
@@ -130,50 +105,42 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
...
@@ -130,50 +105,42 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
auto
GemmM1
=
Number
<
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
>
{};
constexpr
auto
GemmN1
=
Number
<
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
>
{};
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{}
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{}
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
//
c
_block_cluster_desc
//
out_gemm
_block_cluster_desc
const
auto
gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
const
auto
out_
gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_k_
m_global_iterator_hacks
=
constexpr
auto
wei_gemmk_gemm
m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_k_
m_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk_gemm
m_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr
auto
b_k_
n_global_iterator_hacks
=
constexpr
auto
in_gemmk_gemm
n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
b_k_
n_global_move_slice_window_iterator_hack
=
constexpr
auto
in_gemmk_gemm
n_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
// hack for NKHW format
constexpr
auto
c_m0_m1_n0_
n1_global_
tensor_
iterator_hacks
=
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemm
n1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
...
@@ -183,402 +150,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
...
@@ -183,402 +150,40 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
// GEMM
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
decltype
(
gemm_block_cluster_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
1
,
GemmCThreadTransferDstScalarPerVector_GemmM1
,
decltype
(
a_k_m_global_iterator_hacks
),
decltype
(
b_k_n_global_iterator_hacks
),
decltype
(
c_m0_m1_n0_n1_global_tensor_iterator_hacks
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
printf
(
"%s: BlockSize %d, GridSize %d
\n
"
,
__func__
,
BlockSize
,
GridSize
);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
gemm_block_cluster_desc
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
FloatAB
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
FloatAB
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
FloatC
*
,
decltype
(
gemm_block_cluster_desc
),
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
out_gemm_block_cluster_desc
,
gemm_block_cluster_desc
,
wei_gemmk_gemmm_global_iterator_hacks
,
integral_constant
<
bool
,
false
>
{},
in_gemmk_gemmn_global_iterator_hacks
,
integral_constant
<
bool
,
false
>
{});
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
}
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
}
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
true
>
;
launch_kernel
(
// GemmM = K
kernel
,
// GemmN = N * Ho * Wo
dim3
(
GridSize
),
// GemmK = C * Y * X
dim3
(
BlockSize
),
template
<
index_t
GemmMPerBlock
,
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
true
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
true
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
else
{
const
auto
kernel
=
run_gridwise_dynamic_gemm_v1
<
gridwise_gemm
,
ADesc
,
FloatAB
,
BDesc
,
FloatAB
,
CDesc
,
FloatC
,
false
,
false
>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
(
void
__CONSTANT__
*
)
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
(
void
__CONSTANT__
*
)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
);
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
};
#if 0
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t GemmKPerBlock,
index_t
GemmM1
,
index_t GemmMPerThread,
index_t
GemmN1
,
index_t GemmNPerThread,
typename
...
Wei
,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{
template <typename... Wei,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc,
__host__
__device__
constexpr
auto
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_hi_wi_c_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_hi_wi_c_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const InRightPads& in_right_pads,
const
InRightPads
&
in_right_pads
)
const FloatAB* __restrict__ p_wei_global,
{
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -609,12 +214,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
...
@@ -609,12 +214,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if(!
(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
assert
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW == 0))
InRightPadW
==
0
);
{
throw std::runtime_error("wrong! 1x1, stride 1, no padding");
}
// weight tensor
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -641,42 +243,39 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
...
@@ -641,42 +243,39 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
const auto GemmM0 = GemmM / GemmM1;
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{}
;
const auto GemmN0 = GemmN / GemmN1;
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{}
;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
// out_gemm_block_cluster_desc
constexpr auto a_k_m_global_iterator_hacks =
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks
// tensor
constexpr
auto
wei_gemmk_gemmm_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr auto
a_k_
m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr
auto
wei_gemmk_gemm
m_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto
b_k_
n_global_iterator_hacks =
constexpr
auto
in_gemmk_gemm
n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr auto
b_k_
n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr
auto
in_gemmk_gemm
n_global_move_slice_window_iterator_hack
s
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto
c_m0_m1_n0_
n1_global_
tensor_
iterator_hacks =
constexpr
auto
out_gemmm0_gemmm1_gemmn0_gemm
n1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -686,348 +285,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
...
@@ -686,348 +285,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// GEMM
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc
,
in_gemmk_gemmn_global_desc
,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global,
out_gemm_block_cluster_desc
,
integral_constant<bool, true>{},
wei_gemmk_gemmm_global_iterator_hacks
,
integral_constant<bool, true>{});
in_gemmk_gemmn_global_iterator_hacks
,
}
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks
,
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
{
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
const auto kernel =
}
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
}
;
#endif
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
0 → 100644
View file @
d8c89b68
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"

namespace ck {

// Host-side driver for GridwiseDynamicGemm_km_kn_m0m1n0n1_v1.
//
// Checks that the GEMM sizes (M, N, K, read from the runtime tensor
// descriptors) are divisible by the block/thread tiling parameters,
// instantiates the gridwise GEMM, picks the kernel template variant that
// matches the K-loop structure (main loop present / double-tail present),
// launches it `nrepeat` times and returns the average kernel time in ms.
//
// Two experimental descriptor-passing modes are supported:
//   - CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE: descriptors are
//     passed as kernel arguments by value;
//   - CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER: descriptors
//     are copied into device buffers and passed as __CONSTANT__ void
//     pointers (informs the compiler the pointees are non-modifiable).
//
// The trailing tag parameters (iterator hacks / move-slice-window hacks)
// are type-only; they are consumed by the gridwise GEMM instantiation.
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                             const FloatAB* p_b_global,
                                             FloatC* p_c_global,
                                             const AGlobalDesc& a_k_m_global_desc,
                                             const BGlobalDesc& b_k_n_global_desc,
                                             const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                             const CBlockClusterDesc& c_block_cluster_desc,
                                             AGlobalIteratorHacks,
                                             BGlobalIteratorHacks,
                                             CGlobalIteratorHacks,
                                             AGlobalMoveSliceWindowIteratorHacks,
                                             BGlobalMoveSliceWindowIteratorHacks,
                                             index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // GEMM sizes come from the runtime descriptors: A is K x M, B is K x N
    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size no divisible");
    }

    // M1/N1: size of the tile covered by one level-1 thread cluster
    constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
    constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size no divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                              FloatAB,
                                              FloatAcc,
                                              FloatC,
                                              CGlobalMemoryDataOperation,
                                              AGlobalDesc,
                                              BGlobalDesc,
                                              CGlobalDesc,
                                              CBlockClusterDesc,
                                              MPerBlock,
                                              NPerBlock,
                                              KPerBlock,
                                              MPerThread,
                                              NPerThread,
                                              KPerThread,
                                              MLevel0Cluster,
                                              NLevel0Cluster,
                                              MLevel1Cluster,
                                              NLevel1Cluster,
                                              ABlockTransferThreadSliceLengths_K_M,
                                              ABlockTransferThreadClusterLengths_K_M,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              ABlockTransferSrcAccessOrder,
                                              ABlockTransferSrcVectorDim,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_M,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              BBlockTransferThreadSliceLengths_K_N,
                                              BBlockTransferThreadClusterLengths_K_N,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              BBlockTransferSrcAccessOrder,
                                              BBlockTransferSrcVectorDim,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_N,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              CThreadTransferSrcDstAccessOrder,
                                              CThreadTransferSrcDstVectorDim,
                                              CThreadTransferDstScalarPerVector,
                                              AGlobalIteratorHacks,
                                              BGlobalIteratorHacks,
                                              CGlobalIteratorHacks,
                                              AGlobalMoveSliceWindowIteratorHacks,
                                              BGlobalMoveSliceWindowIteratorHacks>;

    // one workgroup per (MPerBlock x NPerBlock) tile of C
    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    // K loop is double-buffered, processing 2*KPerBlock per main iteration;
    // the two flags select which specialized kernel handles the loop tail
    const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    // pass tensor descriptors to the kernel by value
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, true>,
                                                   integral_constant<bool, true>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, true>{},
                                          integral_constant<bool, true>{});
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, true>,
                                                   integral_constant<bool, false>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, true>{},
                                          integral_constant<bool, false>{});
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, false>,
                                                   integral_constant<bool, true>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, false>{},
                                          integral_constant<bool, true>{});
    }
    else
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, false>,
                                                   integral_constant<bool, false>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, false>{},
                                          integral_constant<bool, false>{});
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    // copy the tensor descriptors into device buffers and pass them to the
    // kernel as __CONSTANT__ void pointers
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        // fix: use the void-pointer kernel entry point, consistent with the
        // other three branches (the original instantiated the by-value
        // run_gridwise_operation wrapper here, which does not accept
        // __CONSTANT__ void* descriptor arguments)
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
d8c89b68
...
@@ -184,6 +184,27 @@ struct TensorAdaptor
...
@@ -184,6 +184,27 @@ struct TensorAdaptor
return
get_container_subset
(
idx_hidden
,
BottomDimensionHiddenIds
{});
return
get_container_subset
(
idx_hidden
,
BottomDimensionHiddenIds
{});
}
}
// Debug helper: dump this adaptor's transform chain to stdout.
// For each of the ntransform_ transforms it prints the transform itself and
// the hidden-dimension ids it maps between (lower/upper), then the hidden ids
// of the bottom and top dimension sets. Uses printf so it is callable from
// both host and device code.
__host__ __device__ void Print() const
{
    printf("{");
    printf("TensorAdaptor, ");
    static_for<0, ntransform_, 1>{}([&](auto i) {
        // per-transform record: transform + its lower/upper hidden dim ids
        printf("transforms: ");
        transforms_[i].Print();
        printf("LowerDimensionHiddenIds:");
        LowerDimensionHiddenIdss{}.At(i).Print();
        printf("UpperDimensionHiddenIds:");
        UpperDimensionHiddenIdss{}.At(i).Print();
    });
    printf("BottomDimensionHiddenIds:");
    BottomDimensionHiddenIds::Print();
    printf("TopDimensionHiddenIds:");
    TopDimensionHiddenIds::Print();
    printf("}");
}
private:
private:
Transforms
transforms_
;
Transforms
transforms_
;
ElementSize
element_size_
;
ElementSize
element_size_
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
d8c89b68
...
@@ -12,7 +12,36 @@
...
@@ -12,7 +12,36 @@
namespace
ck
{
namespace
ck
{
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
// __global__ trampoline for GridwiseGemm when tensor descriptors are passed
// to the kernel by value (CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE).
// Forwards the descriptors and data pointers to GridwiseGemm::Run; the two
// bool template parameters select the K-loop specialization (main loop /
// double-tail) as compile-time integral_constant tags.
template <typename GridwiseGemm,
          typename AGlobalDesc,
          typename FloatA,
          typename BGlobalDesc,
          typename FloatB,
          typename CGlobalDesc,
          typename FloatC,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
                                       const FloatA* __restrict__ p_a_global,
                                       const BGlobalDesc b_k_n_global_desc,
                                       const FloatB* __restrict__ p_b_global,
                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                                       FloatC* __restrict__ p_c_global,
                                       const CBlockClusterDesc c_block_cluster_desc)
{
    GridwiseGemm{}.Run(a_k_m_global_desc,
                       p_a_global,
                       b_k_n_global_desc,
                       p_b_global,
                       c_m0_m1_n0_n1_global_desc,
                       p_c_global,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
// non-modifiable parameter address space, so compiler can enable corresponding optimization
...
@@ -26,7 +55,7 @@ template <typename GridwiseGemm,
...
@@ -26,7 +55,7 @@ template <typename GridwiseGemm,
typename
CBlockClusterDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
__global__
void
run_gridwise
_dynamic_gemm_v1
(
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
__global__
void
kernel
_dynamic_gemm_v1
(
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
const
FloatA
*
__restrict__
p_a_global
,
const
FloatA
*
__restrict__
p_a_global
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
const
FloatB
*
__restrict__
p_b_global
,
const
FloatB
*
__restrict__
p_b_global
,
...
...
driver/include/device.hpp
View file @
d8c89b68
...
@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
...
@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
template
<
typename
...
Args
,
typename
F
>
template
<
typename
...
Args
,
typename
F
>
float
launch_and_time_kernel
(
F
kernel
,
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
grid_dim
,
dim3
block_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
std
::
size_t
lds_byte
,
...
@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
...
@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
{
{
KernelTimer
timer
;
KernelTimer
timer
;
printf
(
"%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
printf
(
"Warm up
\n
"
);
// warm up
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
timer
.
Start
();
timer
.
Start
();
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
}
timer
.
End
();
timer
.
End
();
hipGetLastError
();
return
timer
.
GetElapsedTime
()
/
nrepeat
;
return
timer
.
GetElapsedTime
();
}
}
#elif CK_DEVICE_BACKEND_NVIDIA
#elif CK_DEVICE_BACKEND_NVIDIA
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
d8c89b68
...
@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
{
{
using
namespace
ck
;
using
namespace
ck
;
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
std
::
cout
<<
__func__
<<
std
::
endl
;
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
...
@@ -459,18 +468,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -459,18 +468,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
auto
descs
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_pad
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_pad
#elif 0
#elif 0
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_no_pad
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_no_pad
#el
if 1
#el
se
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_1x1
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_1x1
#endif
#endif
<
BlockSize
,
<
GemmMPerBlock
,
GemmNPerBlock
,
GemmM1
,
GemmN1
>
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
);
float
ave_time
=
launch_kernel_dynamic_gemm_v1
<
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TAcc
,
TAcc
,
TOut
,
TOut
,
InMemoryDataOperation
::
Set
,
decltype
(
descs
[
I0
]),
decltype
(
descs
[
I1
]),
decltype
(
descs
[
I2
]),
decltype
(
descs
[
I3
]),
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
...
@@ -483,26 +509,50 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -483,26 +509,50 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
{};
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
conv_driver
.
Run
(
wei_k_c_y_x_desc
,
Sequence
<
2
,
3
,
0
,
1
>
,
in_n_c_hi_wi_desc
,
3
,
out_n_k_ho_wo_desc
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
conv_strides
,
decltype
(
descs
[
I4
])
,
conv_dilations
,
decltype
(
descs
[
I5
])
,
in_left_pads
,
decltype
(
descs
[
I6
])
,
in_right_pads
,
decltype
(
descs
[
I7
])
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()));
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
descs
[
I3
],
descs
[
I4
],
descs
[
I5
],
descs
[
I6
],
descs
[
I7
],
descs
[
I8
],
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
// copy result back to host
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k_ho_wo
.
mData
.
data
());
}
}
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
d8c89b68
...
@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
{
{
using
namespace
ck
;
using
namespace
ck
;
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
std
::
cout
<<
__func__
<<
std
::
endl
;
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
constexpr
auto
N
=
OutDesc
::
GetLengths
()[
I0
];
constexpr
auto
N
=
OutDesc
::
GetLengths
()[
I0
];
constexpr
auto
K
=
OutDesc
::
GetLengths
()[
I1
];
constexpr
auto
K
=
OutDesc
::
GetLengths
()[
I1
];
...
@@ -372,18 +376,33 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -372,18 +376,33 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
auto
descs
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0
#else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif
#endif
<
BlockSize
,
<
GemmMPerBlock
,
GemmNPerBlock
,
GemmM1
,
GemmN1
>
(
wei_k_y_x_c0_desc
,
in_n_hi_wi_c0_desc
,
out_n_ho_wo_k_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
);
float
ave_time
=
launch_kernel_dynamic_gemm_v1
<
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TAcc
,
TAcc
,
TOut
,
TOut
,
InMemoryDataOperation
::
Set
,
decltype
(
descs
[
I0
]),
decltype
(
descs
[
I1
]),
decltype
(
descs
[
I2
]),
decltype
(
descs
[
I3
]),
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
...
@@ -396,27 +415,50 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -396,27 +415,50 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmCThreadTransferDstScalarPerVector_GemmM1
>
{};
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
conv_driver
.
Run
(
wei_k_y_x_c0_desc
,
Sequence
<
2
,
3
,
0
,
1
>
,
in_n_hi_wi_c0_desc
,
1
,
out_n_ho_wo_k_desc
,
GemmCThreadTransferDstScalarPerVector_GemmM1
,
conv_strides
,
decltype
(
descs
[
I4
])
,
conv_dilations
,
decltype
(
descs
[
I5
])
,
in_left_pads
,
decltype
(
descs
[
I6
])
,
in_right_pads
,
decltype
(
descs
[
I7
])
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()));
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
descs
[
I3
],
descs
[
I4
],
descs
[
I5
],
descs
[
I6
],
descs
[
I7
],
descs
[
I8
],
nrepeat
);
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
// copy result back to host
out_n_ho_wo_k_device_buf
.
FromDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
FromDevice
(
out_n_ho_wo_k
.
mData
.
data
());
auto
f_nhwk2nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nhwk2nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
...
...
driver/src/conv_driver.cpp
View file @
d8c89b68
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
0
#elif
1
// 3x3, 71x71
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
C
=
192
;
...
@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
...
@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
// 7x1, 17x17
// 7x1, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
...
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif
0
#elif
1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
<
in_data_t
,
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
<
in_data_t
,
in_vector_size
,
in_vector_size
,
acc_data_t
,
acc_data_t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment