gaoqiong / composable_kernel · Commits · 95710403

Commit 95710403, authored Jun 02, 2021 by Jing Zhang

    add kpack with incorrect results

parent 44078dba

Showing 3 changed files with 220 additions and 217 deletions (+220 -217).
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   +45 -23
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp   +147 -140
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   +28 -54
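All three changes follow one pattern: the GEMM K dimension is split into an outer GemmK0 and an inner GemmK1 = GemmKPack, and every descriptor, LDS tile, and blockwise copy gains the extra K1 axis. The following standalone sketch (not part of the commit, illustrative sizes only, and not the library's descriptor machinery) shows the index arithmetic the unmerge transform encodes: element (k0, m, k1) of the [K0, M, K1] view aliases element (k0 * KPack + k1, m) of the original packed [K, M] view.

// Minimal standalone sketch of the K -> (K0, K1 = KPack) split used throughout this commit.
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int K = 8, M = 4, KPack = 2; // illustrative sizes only
    constexpr int K0 = K / KPack;          // outer K dimension

    for(int k0 = 0; k0 < K0; ++k0)
        for(int m = 0; m < M; ++m)
            for(int k1 = 0; k1 < KPack; ++k1)
            {
                // offset in the packed, row-major [K, M] view
                const int off_2d = (k0 * KPack + k1) * M + m;
                // same element through the [K0, M, K1] view, strides (KPack*M, 1, M)
                const int off_3d = k0 * (KPack * M) + m * 1 + k1 * M;
                assert(off_2d == off_3d);
            }
    std::printf("K=%d split into K0=%d x KPack=%d\n", K, K0, KPack);
    return 0;
}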
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto InRightPadH = in_right_pads[I0];
     const auto InRightPadW = in_right_pads[I1];

+    const auto GemmM = K;
+    const auto GemmN = N * Ho * Wo;
+    const auto GemmK = C * Y * X;
+
     // weight tensor
     const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<1>{}, Sequence<0>{}));

+    const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
+        wei_gemmk_gemmm_global_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+                   make_pass_through_transform(GemmM)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
     // input tensor
     const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
         in_n_c_hi_wi_global_desc,
@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));

+    const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
+        in_gemmk_gemmn_global_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+                   make_pass_through_transform(GemmN)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
     // output tensor
     const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));

-    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+    assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
+    assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
+    const auto GemmK0 = wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0);

-    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
+    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
+           GemmK0 % GemmKPerBlock == 0);

     constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
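With the split in place, the blocked K loop walks GemmK0 rather than GemmK, so the divisibility check moves from GemmK % GemmKPerBlock to GemmK0 % GemmKPerBlock. A quick numeric sanity check with made-up convolution sizes (illustrative only, not taken from the commit):

// Illustrative check of the new divisibility requirement.
#include <cassert>

int main()
{
    const int C = 64, Y = 3, X = 3;           // hypothetical conv sizes
    const int GemmKPack = 4, GemmKPerBlock = 4;

    const int GemmK = C * Y * X;              // 576
    assert(GemmK % GemmKPack == 0);
    const int GemmK0 = GemmK / GemmKPack;     // 144
    assert(GemmK0 % GemmKPerBlock == 0);      // the K loop now steps in K0 units
    return 0;
}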
@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
         make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));

-    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
+    // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
+    constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
+        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
+    constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0>{};

-    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+    // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

     // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
     // tensor hack for NKHW format
@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
             Sequence<0, 0, 0, 0, 0>{},
             Sequence<0, 0, 2, 0, 0>{}));

-    return make_tuple(wei_gemmk_gemmm_global_desc,
-                      in_gemmk_gemmn_global_desc,
+    return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
+                      in_gemmk0_gemmn_gemmk1_global_desc,
                       out_m0_m1_m2_n_global_desc,
                       out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
+                      wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
+                      in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
                       out_m0_m1_m2_n_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+                      wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
+                      in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
 }

 } // namespace ck
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp

@@ -30,16 +30,16 @@ __global__ void
 kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
                               const FloatB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
-                              const AGlobalDesc a_k_m_global_desc,
-                              const BGlobalDesc b_k_n_global_desc,
+                              const AGlobalDesc a_k0_m_k1_global_desc,
+                              const BGlobalDesc b_k0_n_k1_global_desc,
                               const CGlobalDesc c_m0_m1_m2_n_global_desc,
                               const CBlockClusterDesc c_block_cluster_desc)
 {
     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
                       c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
@@ -66,18 +66,18 @@ __global__ void
 kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
                               const FloatB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
-                              const void __CONSTANT__* p_a_k_m_global_desc,
-                              const void __CONSTANT__* p_b_k_n_global_desc,
+                              const void __CONSTANT__* p_a_k0_m_k1_global_desc,
+                              const void __CONSTANT__* p_b_k0_n_k1_global_desc,
                               const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
                               const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k_m_global_desc =
-        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
-    const auto b_k_n_global_desc =
-        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
+    const auto a_k0_m_k1_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
+    const auto b_k0_n_k1_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
     const auto c_m0_m1_m2_n_global_desc =
         *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
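The descriptor arguments arrive as void __CONSTANT__* and are converted in two steps because the descriptor copy constructor cannot take an address_space(4) pointer. A minimal sketch of the same two-step cast pattern, with a plain struct standing in for the descriptor and __CONSTANT__ left empty so it compiles as ordinary C++ (on HIP it would carry an address-space attribute):

#include <cstdio>

#define __CONSTANT__ // empty here; a hypothetical stand-in for the HIP address-space qualifier

struct Desc { int lengths[3]; }; // stand-in for the tensor descriptor type

float use_desc(const void __CONSTANT__* p_desc)
{
    // first cast drops the (hypothetical) qualifier via a plain void*,
    // second cast reinterprets the raw pointer as the concrete descriptor type
    const auto desc = *reinterpret_cast<const Desc*>((const void*)p_desc);
    return static_cast<float>(desc.lengths[0] * desc.lengths[1] * desc.lengths[2]);
}

int main()
{
    Desc d{{4, 128, 4}};
    std::printf("%f\n", use_desc(&d));
    return 0;
}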
@@ -87,8 +87,8 @@ __global__ void
     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
                       c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
@@ -113,21 +113,21 @@ template <index_t BlockSize,
           index_t KPack,
           index_t MRepeat,
           index_t NRepeat,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
+          typename ABlockTransferThreadSliceLengths_K_M_KPack,
+          typename ABlockTransferThreadClusterLengths_K_M_KPack,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
+          index_t ABlockTransferDstScalarPerVector_KPack,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
+          typename BBlockTransferThreadSliceLengths_K_N_KPack,
+          typename BBlockTransferThreadClusterLengths_K_N_KPack,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
+          index_t BBlockTransferDstScalarPerVector_KPack,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
@@ -141,25 +141,26 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<KPack>{});

         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
     }
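The LDS budget is twice the aligned A tile plus twice the aligned B tile, because both are double buffered. A worked version of that arithmetic with illustrative parameters (KPerBlock = 4, MPerBlock = NPerBlock = 128, KPack = 4, fp32; the real values come from the driver's template arguments, and a real aligned descriptor may round the tile size up):

// Worked example of the LDS-size formula above, with assumed parameter values.
#include <cstdio>

int lcm(int a, int b) { int x = a, y = b; while(y) { int t = x % y; x = y; y = t; } return a / x * b; }
int round_up(int v, int align) { return (v + align - 1) / align * align; }

int main()
{
    const int KPerBlock = 4, MPerBlock = 128, NPerBlock = 128, KPack = 4;
    const int max_lds_align = lcm(lcm(1, 1), KPack); // lcm of the two dst vector widths and KPack

    const int a_elems = round_up(KPerBlock * MPerBlock * KPack, max_lds_align); // 2048
    const int b_elems = round_up(KPerBlock * NPerBlock * KPack, max_lds_align); // 2048

    const int lds_bytes = 2 * (a_elems + b_elems) * (int)sizeof(float); // double buffered
    std::printf("LDS bytes: %d\n", lds_bytes);                          // 32768
    return 0;
}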
@@ -168,8 +169,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
+                               const AGlobalDesc& a_k0_m_k1_global_desc,
+                               const BGlobalDesc& b_k0_n_k1_global_desc,
                                const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                FloatAB* __restrict__ p_shared_block,
@@ -182,15 +183,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         constexpr auto I3 = Number<3>{};

         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+            p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+            p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
         auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());

-        const auto K = a_k_m_global_desc.GetLength(I0);
-        const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
+        const auto M  = a_k0_m_k1_global_desc.GetLength(I1);
+        const auto N  = b_k0_n_k1_global_desc.GetLength(I1);

         // divide block work by [M, N]
         const auto block_work_idx =
@@ -204,74 +205,73 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{});
-        // Number<MPerThread>{},
-        // Number<NPerThread>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<KPack>{});

         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, MPerBlock>,
-                                                   ABlockTransferThreadSliceLengths_K_M,
-                                                   ABlockTransferThreadClusterLengths_K_M,
+                                                   Sequence<KPerBlock, MPerBlock, KPack>,
+                                                   ABlockTransferThreadSliceLengths_K_M_KPack,
+                                                   ABlockTransferThreadClusterLengths_K_M_KPack,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(a_k_m_global_desc),
-                                                   decltype(a_k_m_block_desc),
+                                                   decltype(a_k0_m_k1_global_desc),
+                                                   decltype(a_k0_m_k1_block_desc),
                                                    ABlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
-                                                   ABlockTransferSrcVectorDim,
-                                                   1,
-                                                   ABlockTransferSrcScalarPerVector,
-                                                   ABlockTransferDstScalarPerVector_M,
+                                                   Sequence<0, 1, 2>,
+                                                   2, // ABlockTransferSrcVectorDim,
+                                                   2,
+                                                   1, // ABlockTransferSrcScalarPerVector,
+                                                   1, // ABlockTransferDstScalarPerVector_KPack,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
-                a_k_m_global_desc,
-                make_multi_index(0, m_block_data_idx_on_global),
-                a_k_m_block_desc,
-                make_multi_index(0, 0));
+                a_k0_m_k1_global_desc,
+                make_multi_index(0, m_block_data_idx_on_global, 0),
+                a_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock>,
-                                                   BBlockTransferThreadSliceLengths_K_N,
-                                                   BBlockTransferThreadClusterLengths_K_N,
+                                                   Sequence<KPerBlock, NPerBlock, KPack>,
+                                                   BBlockTransferThreadSliceLengths_K_N_KPack,
+                                                   BBlockTransferThreadClusterLengths_K_N_KPack,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(b_k_n_global_desc),
-                                                   decltype(b_k_n_block_desc),
+                                                   decltype(b_k0_n_k1_global_desc),
+                                                   decltype(b_k0_n_k1_block_desc),
                                                    BBlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
-                                                   BBlockTransferSrcVectorDim,
-                                                   1,
-                                                   BBlockTransferSrcScalarPerVector,
-                                                   BBlockTransferDstScalarPerVector_N,
+                                                   Sequence<0, 1, 2>,
+                                                   1, // BBlockTransferSrcVectorDim,
+                                                   2,
+                                                   1, // BBlockTransferSrcScalarPerVector,
+                                                   1, // BBlockTransferDstScalarPerVector_KPack,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
-                b_k_n_global_desc,
-                make_multi_index(0, n_block_data_idx_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+                b_k0_n_k1_global_desc,
+                make_multi_index(0, n_block_data_idx_on_global, 0),
+                b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0));

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
@@ -285,25 +285,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                           NPerBlock % (NPerWave * NRepeat) == 0,
                       "wrong!");

         static_assert(KPerBlock % KPack == 0, "KPerBlock is wrong!");

         constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
-            a_k_m_block_desc,
-            make_tuple(
-                make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
-                make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
+            a_k0_m_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
+                       make_pass_through_transform(Number<KPack>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

         constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
-            b_k_n_block_desc,
-            make_tuple(
-                make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
-                make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
+            b_k0_n_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
+                       make_pass_through_transform(Number<KPack>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

         const auto blockwise_gemm =
             BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
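In the old layout the LDS K axis had to be unmerged into (K0, K1); in the new layout LDS already stores [KPerBlock, MPerBlock, KPack], so the K0 and K1 axes pass through and only M is unmerged into (MRepeat, MPerBlock / MRepeat). A small sketch of the new coordinate map with illustrative sizes, assuming a packed [K0, M, K1] tile (this is not the library's transform code):

// Check that the assumed (k0, m0, m1, k1) -> LDS-offset map covers the tile exactly once.
#include <cassert>
#include <vector>

int main()
{
    const int KPerBlock = 4, MPerBlock = 8, KPack = 2, MRepeat = 2; // illustrative sizes
    const int MPerRepeat = MPerBlock / MRepeat;

    std::vector<bool> visited(KPerBlock * MPerBlock * KPack, false);
    for(int k0 = 0; k0 < KPerBlock; ++k0)
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int m1 = 0; m1 < MPerRepeat; ++m1)
                for(int k1 = 0; k1 < KPack; ++k1)
                {
                    const int m   = m0 * MPerRepeat + m1;                // unmerge of M
                    const int off = (k0 * MPerBlock + m) * KPack + k1;   // packed [K0, M, K1] layout
                    assert(!visited[off]);                               // each element hit exactly once
                    visited[off] = true;
                }
    return 0;
}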
@@ -313,6 +311,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                                  MPerWave,
                                                  NPerWave,
                                                  KPack>{};

         constexpr auto CLayout = blockwise_gemm.GetCLayout();
         constexpr index_t BlkSize = CLayout.GetBlkSize();
@@ -332,10 +331,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
@@ -349,37 +348,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         //     Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
         //     .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
+        constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
+        constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
             AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
+        constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());

         auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -391,77 +392,83 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k0_m_k1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k0_n_k1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);

                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k0_m_k1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k0_n_k1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
-            } while(k_block_data_begin < K - 2 * KPerBlock);
+            } while(k_block_data_begin < K0 - 2 * KPerBlock);
         }

         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
-                                                a_block_slice_copy_step,
-                                                a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
+                                                a_block_slice_copy_step,
+                                                a_k0_m_k1_global_move_slice_window_iterator_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
+                                                b_block_slice_copy_step,
+                                                b_k0_n_k1_global_move_slice_window_iterator_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);

             __syncthreads();
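The loop body itself is unchanged by this commit; only the descriptor names and the loop bound (K becomes K0) differ. For reference, a generic sketch of the even/odd ping-pong double-buffering pattern the loop follows (the real kernel stages data through LDS with blockwise copies and synchronizes the block; this only shows the control flow, with hypothetical stand-in functions):

#include <vector>

// stand-ins for blockwise_gemm.Run and for RunRead + RunWrite of one K0 slab
void process(const std::vector<float>&) {}
void prefetch(std::vector<float>&, int /*k0_begin*/) {}

void pipeline(int K0, int KPerBlock)
{
    std::vector<float> buf_even(256), buf_odd(256);
    int k0_begin = 0;
    prefetch(buf_even, k0_begin);                     // preload the first tile into the "even" buffer

    while(k0_begin < K0 - 2 * KPerBlock)
    {
        prefetch(buf_odd, k0_begin + KPerBlock);      // load the next tile while...
        process(buf_even);                            // ...computing on the current one
        prefetch(buf_even, k0_begin + 2 * KPerBlock); // then swap the roles of the two buffers
        process(buf_odd);
        k0_begin += 2 * KPerBlock;
    }
    // tail: the last one or two tiles are handled outside the main loop
}

int main() { pipeline(64, 4); }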
@@ -507,10 +514,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                     blockwise_gemm.CalculateCThreadOriginDataIndex(mr_i, nr_i, xdlops_i, blk_i);

-                const index_t k_thread_data_on_global =
+                const index_t m_thread_data_on_global =
                     m_block_data_idx_on_global + c_thread_mtx_on_block[I0];

-                const index_t b_thread_data_on_global =
+                const index_t n_thread_data_on_global =
                     n_block_data_idx_on_global + c_thread_mtx_on_block[I1];

                 constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
@@ -528,10 +535,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                     CGlobalMemoryDataOperation,
                     1,
                     true>{c_m0_m1_m2_n_global_desc,
-                          make_multi_index(k_thread_data_on_global / (M2 * M1),
-                                           k_thread_data_on_global % (M2 * M1) / M2,
-                                           k_thread_data_on_global % M2,
-                                           b_thread_data_on_global)}
+                          make_multi_index(m_thread_data_on_global / (M2 * M1),
+                                           m_thread_data_on_global % (M2 * M1) / M2,
+                                           m_thread_data_on_global % M2,
+                                           n_thread_data_on_global)}
                     .Run(c_m0_m1_m2_n_thread_desc,
                          make_tuple(I0, I0, I0, I0),
                          c_blk_buf_,
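Renaming k_/b_thread_data_on_global to m_/n_thread_data_on_global matches what the indices actually are: the global M coordinate is decomposed into (m0, m1, m2), while N stays a single coordinate. A quick check that the decomposition used in make_multi_index above round-trips (the M1 and M2 values are illustrative):

#include <cassert>

int main()
{
    const int M1 = 2, M2 = 4; // illustrative xdlops C-layout factors
    for(int m = 0; m < 64; ++m)
    {
        const int m0 = m / (M2 * M1);
        const int m1 = m % (M2 * M1) / M2;
        const int m2 = m % M2;
        assert(m0 * (M1 * M2) + m1 * M2 + m2 == m); // recovers the original M index
    }
    return 0;
}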
@@ -549,8 +556,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
+                               const AGlobalDesc& a_k0_m_k1_global_desc,
+                               const BGlobalDesc& b_k0_n_k1_global_desc,
                                const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
@@ -563,8 +570,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         Run(p_a_global,
             p_b_global,
             p_c_global,
-            a_k_m_global_desc,
-            b_k_n_global_desc,
+            a_k0_m_k1_global_desc,
+            b_k0_n_k1_global_desc,
             c_m0_m1_m2_n_global_desc,
             c_block_cluster_desc,
             p_shared_block,
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     const auto out_n_k_ho_wo_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(OutDesc::GetLengths()));

     const auto conv_strides   = sequence_to_tuple_of_number(ConvStrides{});
     const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
     const auto in_left_pads   = sequence_to_tuple_of_number(InLeftPads{});
     const auto in_right_pads  = sequence_to_tuple_of_number(InRightPads{});
 #endif

 #if 0
     constexpr index_t BlockSize = 64;

     constexpr index_t GemmMPerBlock = 64;
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerWave = 64;
     constexpr index_t GemmNPerWave = 64;
     constexpr index_t GemmKPack = 1;

     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;

     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

     using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 2>;
     using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;

     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
 #else
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
+    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerWave = 32;
-    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmMPerWave = 64;
+    constexpr index_t GemmNPerWave = 64;
     constexpr index_t GemmKPack = 4;

-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
+    constexpr index_t MRepeat = 1;
+    constexpr index_t NRepeat = 1;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, GemmKPack>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1;

-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, GemmKPack>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1;

     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
 #endif

     const auto descs =
         transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
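A quick consistency check on the new A-side block-transfer shapes chosen above: per-thread slice lengths times thread-cluster lengths must tile the [GemmKPerBlock, GemmMPerBlock, GemmKPack] copy region, and the cluster must use exactly BlockSize threads. The values below are the ones from this hunk:

#include <cassert>

int main()
{
    const int BlockSize = 256, GemmKPerBlock = 4, GemmMPerBlock = 128, GemmKPack = 4;

    const int slice[3]   = {1, 2, GemmKPack}; // GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
    const int cluster[3] = {4, 64, 1};        // GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1

    assert(slice[0] * cluster[0] == GemmKPerBlock);           // K0 direction is covered
    assert(slice[1] * cluster[1] == GemmMPerBlock);           // M direction is covered
    assert(slice[2] * cluster[2] == GemmKPack);               // K1 (KPack) direction is covered
    assert(cluster[0] * cluster[1] * cluster[2] == BlockSize); // exactly one workgroup of threads
    return 0;
}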
@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
         GemmKPack,
         MRepeat,
         NRepeat,
-        GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-        GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-        Sequence<1, 0>,
-        Sequence<1, 0>,
-        0,
+        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
+        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
+        Sequence<1, 0, 2>,
+        Sequence<1, 0, 2>,
+        2,
         GemmABlockTransferSrcScalarPerVector_GemmK,
-        GemmABlockTransferDstScalarPerVector_GemmM,
+        GemmABlockTransferDstScalarPerVector_KPack,
         false, // don't move back src coordinate after threadwise copy
-        GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-        GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-        Sequence<0, 1>,
-        Sequence<0, 1>,
+        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
+        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
+        Sequence<0, 2, 1>,
+        Sequence<0, 2, 1>,
         1,
         GemmBBlockTransferSrcScalarPerVector_GemmN,
-        GemmBBlockTransferDstScalarPerVector_GemmN,
+        GemmBBlockTransferDstScalarPerVector_KPack,
         false, // don't move back src coordinate after threadwise copy, which will be fused with
                // MoveSrcSliceWindow() to save addr computation
         Sequence<2, 3, 0, 1>,