Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
758f576a
Commit
758f576a
authored
Mar 13, 2021
by
root
Browse files
parameters clean
parent
6e59255a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
63 deletions
+45
-63
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+22
-31
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+7
-11
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+16
-21
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
758f576a
...
...
@@ -9,22 +9,17 @@
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
KPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
CYXPerBlock
,
index_t
KPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
CYXPerThread
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
...
...
@@ -130,12 +125,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
GemmM
=
K
;
const
auto
GemmN
=
N
*
Ho
*
Wo
;
const
auto
GemmK
=
C
*
Y
*
X
;
const
auto
CYX
=
C
*
Y
*
X
;
if
(
!
(
GemmM
%
GemmM
PerBlock
==
0
&&
GemmN
%
GemmN
PerBlock
==
0
&&
GemmK
%
GemmK
PerBlock
==
0
))
if
(
!
(
K
%
K
PerBlock
==
0
&&
Ho
%
HPerBlock
==
0
&&
Wo
%
W
PerBlock
==
0
&&
CYX
%
CYX
PerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
...
...
@@ -182,16 +175,14 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_n_ho_wo_global_desc
),
decltype
(
out_gemmm_n_ho_wo_global_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
KPerBlock
,
HPerBlock
,
WPerBlock
,
CYXPerBlock
,
KPerThread
,
HPerThread
,
WPerThread
,
CYXPerThread
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
...
...
@@ -218,11 +209,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmN
PerBlock
);
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Ho
/
HPerBlock
)
*
(
Wo
/
W
PerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmK
PerBlock
)
/
(
2
*
GemmK
PerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
(
CYX
+
CYX
PerBlock
)
/
(
2
*
CYX
PerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmK
PerBlock
)
%
2
==
0
;
const
bool
has_double_tail_k_block_loop
=
(
CYX
/
CYX
PerBlock
)
%
2
==
0
;
index_t
nrepeat
=
100
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
758f576a
...
...
@@ -19,15 +19,13 @@ template <index_t BlockSize,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
index_t
KPerBlock
,
index_t
HWPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
CYXPerBlock
,
index_t
KPerThread
,
index_t
HWPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
CYXPerThread
,
index_t
MLevel0Cluster
,
index_t
NLevel0Cluster
,
index_t
MLevel1Cluster
,
index_t
NLevel1Cluster
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterArrangeOrder
,
...
...
@@ -99,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N]
#if 1
const
auto
m_block_work_num
=
K
/
Number
<
KPerBlock
>
{};
const
auto
nhw_block_work_num
=
(
N
*
H
*
W
)
/
Number
<
HWPerBlock
>
{};
const
auto
nhw_block_work_num
=
(
N
*
H
*
W
)
/
(
Number
<
H
PerBlock
>
{}
*
Number
<
WPerBlock
>
{}
)
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
nhw_block_work_num
;
const
index_t
nhw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
nhw_block_work_num
;
...
...
@@ -120,10 +118,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const
index_t
w_block_data_on_global
=
nhw_block_work_id
*
8
;
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{},
Number
<
KPerThread
>
{},
Number
<
HWPerThread
>
{});
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
Number
<
KPerThread
>
{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
758f576a
...
...
@@ -70,18 +70,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
*
3
*
3
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HPerBlock
=
8
;
constexpr
index_t
WPerBlock
=
8
;
constexpr
index_t
CYXPerBlock
=
4
*
3
*
3
;
constexpr
index_t
GemmMPerThread
=
16
;
constexpr
index_t
GemmNPerThread
=
1
;
constexpr
index_t
GemmKPerThread
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
1
;
constexpr
index_t
GemmMLevel1Cluster
=
1
;
constexpr
index_t
GemmNLevel1Cluster
=
64
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
HPerThread
=
1
;
constexpr
index_t
WPerThread
=
1
;
constexpr
index_t
CYXPerThread
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
9
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
...
...
@@ -102,16 +99,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
BlockSize
,
TDevice
,
TDevice
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
KPerBlock
,
HPerBlock
,
WPerBlock
,
CYXPerBlock
,
KPerThread
,
HPerThread
,
WPerThread
,
CYXPerThread
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment