Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1311e039
Commit
1311e039
authored
Jun 03, 2021
by
Chao Liu
Browse files
refactor
parent
49b926b6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
103 additions
and
99 deletions
+103
-99
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
...l/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
+67
-63
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+36
-36
No files found.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
View file @
1311e039
...
@@ -59,31 +59,31 @@ template <index_t BlockSize,
...
@@ -59,31 +59,31 @@ template <index_t BlockSize,
typename
AKMGridDesc
,
typename
AKMGridDesc
,
typename
BKNGridDesc
,
typename
BKNGridDesc
,
typename
CMNGridDesc
,
typename
CMNGridDesc
,
index_t
MPerBlock
,
index_t
MPerBlock
M1
,
index_t
NPerBlock
,
index_t
NPerBlock
N1
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
M1PerThread
,
index_t
M1PerThread
M111
,
index_t
N1PerThread
,
index_t
N1PerThread
N111
,
index_t
KPerThread
,
index_t
KPerThread
,
index_t
M1
N
1ThreadClusterM100
,
index_t
M1
1N1
1ThreadClusterM1
1
00
,
index_t
M1
N
1ThreadClusterN100
,
index_t
M1
1N1
1ThreadClusterN1
1
00
,
index_t
M1
N
1ThreadClusterM101
,
index_t
M1
1N1
1ThreadClusterM1
1
01
,
index_t
M1
N
1ThreadClusterN101
,
index_t
M1
1N1
1ThreadClusterN1
1
01
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadSliceLengths_K_M
0_M1
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
0_M1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_M
,
index_t
ABlockTransferDstScalarPerVector_M
1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K_N
,
typename
BBlockTransferThreadSliceLengths_K_N
0_N1
,
typename
BBlockTransferThreadClusterLengths_K_N
,
typename
BBlockTransferThreadClusterLengths_K_N
0_N1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_N
,
index_t
BBlockTransferDstScalarPerVector_N
1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
...
@@ -102,20 +102,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -102,20 +102,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
1
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
1
>
{},
Number
<
M1PerThread
>
{},
Number
<
M1PerThread
M111
>
{},
Number
<
N1PerThread
>
{});
Number
<
N1PerThread
N111
>
{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_k_m_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
a_k_m_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
M1
>
{}),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
b_k_n_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
b_k_n_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
N1
>
{}),
max_lds_align
);
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_aligned_space_size
=
constexpr
auto
a_block_aligned_space_size
=
...
@@ -139,12 +139,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -139,12 +139,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
return
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
return
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K
==
b_k_n_grid_desc
.
GetLength
(
I0
))
&&
K
==
b_k_n_grid_desc
.
GetLength
(
I0
))
&&
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
);
(
M
%
MPerBlock
M1
==
0
&&
N
%
NPerBlock
N1
==
0
&&
K
%
KPerBlock
==
0
);
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
{
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
grid_size
=
(
M
/
MPerBlock
M1
)
*
(
N
/
NPerBlock
N1
);
return
grid_size
;
return
grid_size
;
}
}
...
@@ -169,7 +169,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -169,7 +169,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
const
auto
K
=
a_k_m_grid_desc
.
GetLength
(
I0
);
const
auto
K
=
a_k_m_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
M1
=
Number
<
MPerBlock
>
{};
const
auto
M1
=
Number
<
MPerBlock
M1
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
a_k_m0_m1_grid_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
a_k_m0_m1_grid_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -187,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -187,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
const
auto
K
=
b_k_n_grid_desc
.
GetLength
(
I0
);
const
auto
K
=
b_k_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
const
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
N1
=
Number
<
NPerBlock
N1
>
{};
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
const
auto
b_k_n0_n1_grid_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
b_k_n0_n1_grid_desc
=
transform_dynamic_tensor_descriptor
(
...
@@ -205,14 +205,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -205,14 +205,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
M1
=
Number
<
MPerBlock
M1
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
N1
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
constexpr
auto
M11
=
Number
<
M1N1ThreadClusterM100
*
M1N1ThreadClusterM101
*
M1PerThread
>
{};
constexpr
auto
M11
=
constexpr
auto
N11
=
Number
<
M1N1ThreadClusterN100
*
M1N1ThreadClusterN101
*
N1PerThread
>
{};
Number
<
M11N11ThreadClusterM1100
*
M11N11ThreadClusterM1101
*
M1PerThreadM111
>
{};
constexpr
auto
N11
=
Number
<
M11N11ThreadClusterN1100
*
M11N11ThreadClusterN1101
*
N1PerThreadN111
>
{};
constexpr
auto
M10
=
M1
/
M11
;
constexpr
auto
M10
=
M1
/
M11
;
constexpr
auto
N10
=
N1
/
N11
;
constexpr
auto
N10
=
N1
/
N11
;
...
@@ -233,8 +235,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -233,8 +235,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
M1
=
Number
<
MPerBlock
M1
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
N1
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
...
@@ -285,38 +287,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -285,38 +287,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
const
index_t
n0_idx
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I1
]);
const
index_t
n0_idx
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I1
]);
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
1
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
1
>
{},
Number
<
M1PerThread
>
{},
Number
<
M1PerThread
M111
>
{},
Number
<
N1PerThread
>
{});
Number
<
N1PerThread
N111
>
{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_k_m_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
a_k_m_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
M1
>
{}),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
b_k_n_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
b_k_n_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
N1
>
{}),
max_lds_align
);
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_k_m0_m1_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
a_k_m0_m1_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
I1
,
Number
<
MPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
I1
,
Number
<
MPerBlock
M1
>
{}),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
b_k_n0_n1_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
constexpr
auto
b_k_n0_n1_block_desc
=
make_dynamic_naive_tensor_descriptor_aligned_v2
(
make_tuple
(
Number
<
KPerBlock
>
{},
I1
,
Number
<
NPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
KPerBlock
>
{},
I1
,
Number
<
NPerBlock
N1
>
{}),
max_lds_align
);
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
Sequence
<
KPerBlock
,
1
,
MPerBlock
>
,
Sequence
<
KPerBlock
,
1
,
MPerBlock
M1
>
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadSliceLengths_K_M
0_M1
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
0_M1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
...
@@ -327,7 +329,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -327,7 +329,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
2
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_M
,
ABlockTransferDstScalarPerVector_M
1
,
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
...
@@ -340,9 +342,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -340,9 +342,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseDynamicTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
Sequence
<
KPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
KPerBlock
,
1
,
NPerBlock
N1
>
,
BBlockTransferThreadSliceLengths_K_N
,
BBlockTransferThreadSliceLengths_K_N
0_N1
,
BBlockTransferThreadClusterLengths_K_N
,
BBlockTransferThreadClusterLengths_K_N
0_N1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
...
@@ -353,7 +355,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -353,7 +355,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
2
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_N
,
BBlockTransferDstScalarPerVector_N
1
,
1
,
1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
...
@@ -364,9 +366,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -364,9 +366,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// a_mtx[KPerBlock, MPerBlock
M1
] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock
N1
] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock
M1
, NPerBlock
N1
] is distributed among threads, and saved in
// register
// register
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
<
BlockSize
,
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
<
BlockSize
,
...
@@ -375,15 +377,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -375,15 +377,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
FloatAcc
,
FloatAcc
,
decltype
(
a_k_m_block_desc
),
decltype
(
a_k_m_block_desc
),
decltype
(
b_k_n_block_desc
),
decltype
(
b_k_n_block_desc
),
M1PerThread
,
M1PerThread
M111
,
N1PerThread
,
N1PerThread
N111
,
KPerThread
,
KPerThread
,
M1
N
1ThreadClusterM100
,
M1
1N1
1ThreadClusterM1
1
00
,
M1
N
1ThreadClusterN100
,
M1
1N1
1ThreadClusterN1
1
00
,
M1
N
1ThreadClusterM101
,
M1
1N1
1ThreadClusterM1
1
01
,
M1
N
1ThreadClusterN101
,
M1
1N1
1ThreadClusterN1
1
01
,
M1PerThread
,
M1PerThread
M111
,
N1PerThread
>
{};
N1PerThread
N111
>
{};
constexpr
auto
c_m10_n10_m11_n11_thread_tensor_lengths
=
constexpr
auto
c_m10_n10_m11_n11_thread_tensor_lengths
=
decltype
(
blockwise_gemm
)
::
GetCM0M1N0N1ThreadTensorLengths
();
decltype
(
blockwise_gemm
)
::
GetCM0M1N0N1ThreadTensorLengths
();
...
@@ -559,14 +561,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
...
@@ -559,14 +561,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// output: register to global memory
// output: register to global memory
{
{
constexpr
index_t
M11
=
M1PerThread
*
M1N1ThreadClusterM100
*
M1N1ThreadClusterM101
;
constexpr
index_t
M11
=
constexpr
index_t
N11
=
N1PerThread
*
M1N1ThreadClusterN100
*
M1N1ThreadClusterN101
;
M1PerThreadM111
*
M11N11ThreadClusterM1100
*
M11N11ThreadClusterM1101
;
constexpr
index_t
N11
=
N1PerThreadN111
*
M11N11ThreadClusterN1100
*
M11N11ThreadClusterN1101
;
constexpr
index_t
M10
=
MPerBlock
/
M11
;
constexpr
index_t
M10
=
MPerBlock
M1
/
M11
;
constexpr
index_t
N10
=
NPerBlock
/
N11
;
constexpr
index_t
N10
=
NPerBlock
N1
/
N11
;
constexpr
index_t
M111
=
M1PerThread
;
constexpr
index_t
M111
=
M1PerThread
M111
;
constexpr
index_t
N111
=
N1PerThread
;
constexpr
index_t
N111
=
N1PerThread
N111
;
constexpr
auto
c_m0_m10_m11_n0_n10_n11_thread_desc
=
constexpr
auto
c_m0_m10_m11_n0_n10_n11_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
1311e039
...
@@ -83,32 +83,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -83,32 +83,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
// b thread copy 4x1
// b thread copy 4x1
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
M1
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
N1
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmM
1
PerThread
M111
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmN
1
PerThread
N111
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmM
Level0
Cluster
=
2
;
constexpr
index_t
GemmM
11N11Thread
Cluster
M1101
=
2
;
constexpr
index_t
Gemm
NLevel0
Cluster
=
2
;
constexpr
index_t
Gemm
M11N11Thread
Cluster
N1101
=
2
;
constexpr
index_t
GemmM
Level1
Cluster
=
8
;
constexpr
index_t
GemmM
11N11Thread
Cluster
M1100
=
8
;
constexpr
index_t
Gemm
NLevel1
Cluster
=
8
;
constexpr
index_t
Gemm
M11N11Thread
Cluster
N1100
=
8
;
using
GemmABlockTransferThreadSliceLengths_
GemmK_GemmM0_Gemm
M1
=
Sequence
<
4
,
1
,
1
>
;
using
GemmABlockTransferThreadSliceLengths_
K_M0_
M1
=
Sequence
<
4
,
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_
GemmK_GemmM0_Gemm
M1
=
Sequence
<
2
,
1
,
128
>
;
using
GemmABlockTransferThreadClusterLengths_
K_M0_
M1
=
Sequence
<
2
,
1
,
128
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_
Gemm
K
=
4
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_K
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_
Gemm
M1
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_M1
=
1
;
using
GemmBBlockTransferThreadSliceLengths_
GemmK_GemmN0_Gemm
N1
=
Sequence
<
4
,
1
,
1
>
;
using
GemmBBlockTransferThreadSliceLengths_
K_N0_
N1
=
Sequence
<
4
,
1
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_
GemmK_GemmN0_Gemm
N1
=
Sequence
<
2
,
1
,
128
>
;
using
GemmBBlockTransferThreadClusterLengths_
K_N0_
N1
=
Sequence
<
2
,
1
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_
Gemm
N1
=
1
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_N1
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_
Gemm
N1
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_N1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_
Gemm
N11
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_N11
=
1
;
#endif
#endif
const
auto
descs
=
const
auto
descs
=
...
@@ -174,35 +174,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -174,35 +174,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
decltype
(
wei_gemmk_gemmm_grid_desc
),
decltype
(
wei_gemmk_gemmm_grid_desc
),
decltype
(
in_gemmk_gemmn_grid_desc
),
decltype
(
in_gemmk_gemmn_grid_desc
),
decltype
(
out_gemmm_gemmn_grid_desc
),
decltype
(
out_gemmm_gemmn_grid_desc
),
GemmMPerBlock
,
GemmMPerBlock
M1
,
GemmNPerBlock
,
GemmNPerBlock
N1
,
GemmKPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmM
1
PerThread
M111
,
GemmNPerThread
,
GemmN
1
PerThread
N111
,
GemmKPerThread
,
GemmKPerThread
,
GemmM
Level1
Cluster
,
GemmM
11N11Thread
Cluster
M1100
,
Gemm
NLevel1
Cluster
,
Gemm
M11N11Thread
Cluster
N1100
,
GemmM
Level0
Cluster
,
GemmM
11N11Thread
Cluster
M1101
,
Gemm
NLevel0
Cluster
,
Gemm
M11N11Thread
Cluster
N1101
,
GemmABlockTransferThreadSliceLengths_
GemmK_GemmM0_Gemm
M1
,
GemmABlockTransferThreadSliceLengths_
K_M0_
M1
,
GemmABlockTransferThreadClusterLengths_
GemmK_GemmM0_Gemm
M1
,
GemmABlockTransferThreadClusterLengths_
K_M0_
M1
,
Sequence
<
2
,
1
,
0
>
,
// ABlockTransferThreadClusterArrangeOrder
Sequence
<
2
,
1
,
0
>
,
// ABlockTransferThreadClusterArrangeOrder
Sequence
<
2
,
1
,
0
>
,
// ABlockTransferSrcAccessOrder
Sequence
<
2
,
1
,
0
>
,
// ABlockTransferSrcAccessOrder
0
,
// ABlockTransferSrcVectorDim
0
,
// ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_
Gemm
K
,
GemmABlockTransferSrcScalarPerVector_K
,
GemmABlockTransferDstScalarPerVector_
Gemm
M1
,
GemmABlockTransferDstScalarPerVector_M1
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_
GemmK_GemmN0_Gemm
N1
,
GemmBBlockTransferThreadSliceLengths_
K_N0_
N1
,
GemmBBlockTransferThreadClusterLengths_
GemmK_GemmN0_Gemm
N1
,
GemmBBlockTransferThreadClusterLengths_
K_N0_
N1
,
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferSrcAccessOrder
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_
Gemm
N1
,
GemmBBlockTransferSrcScalarPerVector_N1
,
GemmBBlockTransferDstScalarPerVector_
Gemm
N1
,
GemmBBlockTransferDstScalarPerVector_N1
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_
Gemm
N11
,
GemmCThreadTransferDstScalarPerVector_N11
,
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment