Project: gaoqiong / composable_kernel

Commit 49b926b6, authored Jun 02, 2021 by Chao Liu
Parent: b8382727

    overhauling fwd-v4r4

Showing 4 changed files with 211 additions and 574 deletions.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp (+49, -66)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp (+108, -104)
composable_kernel/include/utility/math.hpp (+1, -1)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp (+53, -403)

composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
@@ -26,21 +26,21 @@ template <index_t BlockSize,
           index_t M1N1ThreadClusterN10,
           index_t M1N1ThreadClusterM11,
           index_t M1N1ThreadClusterN11,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
+          typename ABlockTransferThreadSliceLengths_K_M0_M1,
+          typename ABlockTransferThreadClusterLengths_K_M0_M1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
+          index_t ABlockTransferDstScalarPerVector_M1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
+          typename BBlockTransferThreadSliceLengths_K_N0_N1,
+          typename BBlockTransferThreadClusterLengths_K_N0_N1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
+          index_t BBlockTransferDstScalarPerVector_N1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     // GEMM
     using GridwiseGemm =
-        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
+        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
                                           FloatAB,
                                           FloatAcc,
                                           FloatC,
                                           CGlobalMemoryDataOperation,
                                           AKMGridDesc,
                                           BKNGridDesc,
                                           CMNGridDesc,
                                           MPerBlock,
                                           NPerBlock,
                                           KPerBlock,
                                           M1PerThread,
                                           N1PerThread,
                                           KPerThread,
                                           M1N1ThreadClusterM10,
                                           M1N1ThreadClusterN10,
                                           M1N1ThreadClusterM11,
                                           M1N1ThreadClusterN11,
-                                          ABlockTransferThreadSliceLengths_K_M,
-                                          ABlockTransferThreadClusterLengths_K_M,
+                                          ABlockTransferThreadSliceLengths_K_M0_M1,
+                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                           ABlockTransferThreadClusterArrangeOrder,
                                           ABlockTransferSrcAccessOrder,
                                           ABlockTransferSrcVectorDim,
                                           ABlockTransferSrcScalarPerVector,
-                                          ABlockTransferDstScalarPerVector_M,
+                                          ABlockTransferDstScalarPerVector_M1,
                                           AThreadTransferSrcResetCoordinateAfterRun,
-                                          BBlockTransferThreadSliceLengths_K_N,
-                                          BBlockTransferThreadClusterLengths_K_N,
+                                          BBlockTransferThreadSliceLengths_K_N0_N1,
+                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                           BBlockTransferThreadClusterArrangeOrder,
                                           BBlockTransferSrcAccessOrder,
                                           BBlockTransferSrcVectorDim,
                                           BBlockTransferSrcScalarPerVector,
-                                          BBlockTransferDstScalarPerVector_N,
+                                          BBlockTransferDstScalarPerVector_N1,
                                           BThreadTransferSrcResetCoordinateAfterRun,
                                           CThreadTransferSrcDstAccessOrder,
                                           CThreadTransferSrcDstVectorDim,
                                           CThreadTransferDstScalarPerVector,
                                           AGridIteratorHacks,
                                           BGridIteratorHacks,
                                           CGridIteratorHacks,
                                           AGridMoveSliceWindowIteratorHacks,
                                           BGridMoveSliceWindowIteratorHacks>;

     const auto M = a_k_m_grid_desc.GetLength(I1);
     const auto N = b_k_n_grid_desc.GetLength(I1);
@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k_m_grid_desc
,
b_k_n_grid_desc
,
c_m_n_grid_desc
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting"
);
}
const
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
...
...
@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
...
@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
...
@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
...
@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
...
@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
...
@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
...
@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
...
@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
@@ -15,8 +15,6 @@ namespace ck {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AKMGridDesc
,
typename
BKNGridDesc
,
typename
AKM0M1GridDesc
,
typename
BKN0N1GridDesc
,
typename
CM0M10M11N0N10N11GridDesc
,
...
...
@@ -31,8 +29,6 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         const AKMGridDesc a_k_m_grid_desc,
         const BKNGridDesc b_k_n_grid_desc,
         const AKM0M1GridDesc a_k_m0_m1_grid_desc,
         const BKN0N1GridDesc b_k_n0_n1_grid_desc,
         const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -47,8 +43,6 @@ __global__ void
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
                       a_k_m_grid_desc,
                       b_k_n_grid_desc,
                       a_k_m0_m1_grid_desc,
                       b_k_n0_n1_grid_desc,
                       c_m0_m10_m11_n0_n10_n11_grid_desc,
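For orientation, the wrapper pattern around this `Run` call is worth spelling out. A hedged sketch (the file's actual kernel body is collapsed in this diff, so details may differ): `GridwiseGemm` reports its own LDS footprint, the `__global__` wrapper allocates it as a `__shared__` array, and everything else happens in `GridwiseGemm::Run`:

```cpp
// sketch only: descriptor parameters elided, names follow the diff
template <typename GridwiseGemm, typename FloatAB, typename FloatC>
__global__ void kernel_dynamic_gemm_v1r2_sketch(const FloatAB* __restrict__ p_a_grid,
                                                const FloatAB* __restrict__ p_b_grid,
                                                FloatC* __restrict__ p_c_grid)
{
    // LDS byte count comes from the gridwise struct itself (see the
    // GetSharedMemoryNumberOfByte hunk later in this file's diff)
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    // the real kernel also forwards the grid descriptors by value, as shown above
    GridwiseGemm::Run(p_a_grid, p_b_grid, p_c_grid, p_shared_block /*, descriptors... */);
}
```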
@@ -99,7 +93,7 @@ template <index_t BlockSize,
           typename CGridIteratorHacks,
           typename AGridMoveSliceWindowIteratorHacks,
           typename BGridMoveSliceWindowIteratorHacks>
-struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
+struct GridwiseDynamicGemm_km_kn_mn_v1r2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -124,13 +118,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size =
+        constexpr auto a_block_aligned_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_space_size =
+        constexpr auto b_block_aligned_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

-        return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
     }

     __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
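The sizing rule in this hunk is easy to sanity-check on the host. Below is a minimal standalone sketch (hedged: plain `std::size_t` stands in for `ck::Number<>`, and `integer_least_multiple` is reimplemented locally) of how the doubled, alignment-padded A and B tiles add up:

```cpp
#include <cstddef>

// smallest multiple of align that is >= x (mirrors math::integer_least_multiple)
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

// total LDS bytes: two copies of each tile (double buffering), each padded
// up to the LDS alignment before being summed
constexpr std::size_t lds_bytes(std::size_t a_elems,
                                std::size_t b_elems,
                                std::size_t max_lds_align,
                                std::size_t elem_bytes)
{
    const std::size_t a_aligned = integer_least_multiple(a_elems, max_lds_align);
    const std::size_t b_aligned = integer_least_multiple(b_elems, max_lds_align);
    return 2 * (a_aligned + b_aligned) * elem_bytes;
}

// e.g. an 8x128 fp32 tile for both A and B with 4-element alignment
static_assert(lds_bytes(8 * 128, 8 * 128, 4, 4) == 16384, "double-buffered LDS size");
```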
@@ -178,13 +172,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M1 = Number<MPerBlock>{};
         const auto M0 = M / M1;

-        const auto a_k_m0_m1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
             a_k_m_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

-        return a_k_m0_m1_block_clusterized_grid_desc;
+        return a_k_m0_m1_grid_desc;
     }

     __host__ __device__ static constexpr auto
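`MakeAKM0M1GridDescriptor` only re-indexes existing memory: the unmerge transform splits `m` into `(m0, m1)` with `m1` running over `MPerBlock`, and a 3-index coordinate lands on exactly the element the old 2-index coordinate did. A small hypothetical sketch of that index arithmetic (strides and values invented for illustration):

```cpp
#include <cassert>

// offset of (k, m) in a strided 2-D view
int offset_k_m(int k, int m, int stride_k, int stride_m)
{
    return k * stride_k + m * stride_m;
}

// offset of (k, m0, m1) after unmerging m into (m0, m1) with m1-extent M1:
// the transform is pure index math, m = m0 * M1 + m1
int offset_k_m0_m1(int k, int m0, int m1, int M1, int stride_k, int stride_m)
{
    return offset_k_m(k, m0 * M1 + m1, stride_k, stride_m);
}

int main()
{
    const int M1 = 128; // MPerBlock
    assert(offset_k_m0_m1(2, 3, 5, M1, /*stride_k=*/4096, /*stride_m=*/1) ==
           offset_k_m(2, 3 * M1 + 5, 4096, 1));
}
```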
@@ -196,13 +190,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto N1 = Number<NPerBlock>{};
         const auto N0 = N / N1;

-        const auto b_k_n0_n1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+        const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
             b_k_n_grid_desc,
             make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

-        return b_k_n0_n1_block_clusterized_grid_desc;
+        return b_k_n0_n1_grid_desc;
     }

     __host__ __device__ static constexpr auto
@@ -246,7 +240,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto N0 = N / N1;

         const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-            make_cluster_descriptor_v2(make_tuple(M0, N0));
+            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
+                                             make_tuple(Sequence<0, 1>{}),
+                                             make_tuple(Sequence<0>{}));

         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }
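The replacement expresses the block-id decomposition as a merge transform instead of a cluster descriptor, but the arithmetic it encodes is plain div/mod: merging `(M0, N0)` into one id means `blockid = m0 * N0 + n0`, and `CalculateBottomIndex` inverts that. A hedged sketch of just the index math:

```cpp
struct M0N0
{
    int m0, n0;
};

// inverse of the merge transform blockid = m0 * N0 + n0
M0N0 blockid_to_m0_n0(int block_1d_id, int N0)
{
    return {block_1d_id / N0, block_1d_id % N0};
}
// e.g. with N0 = 8, block 19 works on block-tile (m0, n0) = (2, 3)
```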
@@ -263,8 +259,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                  const FloatAB* __restrict__ p_b_grid,
                  FloatC* __restrict__ p_c_grid,
                  FloatAB* __restrict__ p_shared_block,
                  const AKMGridDesc& a_k_m_grid_desc,
                  const BKNGridDesc& b_k_n_grid_desc,
                  const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
                  const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
                  const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
@@ -273,30 +267,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                  integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
+            p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
+            p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

-        const auto K = a_k_m_grid_desc.GetLength(I0);
-        const auto M = a_k_m_grid_desc.GetLength(I1);
-        const auto N = b_k_n_grid_desc.GetLength(I1);
+        const auto K = a_k_m0_m1_grid_desc.GetLength(I0);

         // divide block work by [M, N]
-        const auto block_work_idx = c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
-            make_multi_index(get_block_1d_id()));
+        const auto c_m0_n0_block_cluster_idx =
+            c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_block_1d_id()));

         // HACK: this force index data into SGPR
-        const index_t m_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
-        const index_t n_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
-
-        const index_t m_block_data_idx_on_global =
-            __builtin_amdgcn_readfirstlane(m_block_work_idx * MPerBlock);
-        const index_t n_block_data_idx_on_global =
-            __builtin_amdgcn_readfirstlane(n_block_work_idx * NPerBlock);
+        const index_t m0_idx = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
+        const index_t n0_idx = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -314,59 +300,67 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

+        // A matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlock>{}), max_lds_align);
+
+        // B matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlock>{}), max_lds_align);
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, MPerBlock>,
+                                                   Sequence<KPerBlock, 1, MPerBlock>,
                                                    ABlockTransferThreadSliceLengths_K_M,
                                                    ABlockTransferThreadClusterLengths_K_M,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(a_k_m_grid_desc),
-                                                   decltype(a_k_m_block_desc),
+                                                   decltype(a_k_m0_m1_grid_desc),
+                                                   decltype(a_k_m0_m1_block_desc),
                                                    ABlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
+                                                   Sequence<0, 1, 2>,
                                                    ABlockTransferSrcVectorDim,
-                                                   1,
+                                                   2,
                                                    ABlockTransferSrcScalarPerVector,
                                                    ABlockTransferDstScalarPerVector_M,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                a_k_m_grid_desc,
-                make_multi_index(0, m_block_data_idx_on_global),
-                a_k_m_block_desc,
-                make_multi_index(0, 0));
+                                                   true>(
+                a_k_m0_m1_grid_desc,
+                make_multi_index(0, m0_idx, 0),
+                a_k_m0_m1_block_desc,
+                make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock>,
+                                                   Sequence<KPerBlock, 1, NPerBlock>,
                                                    BBlockTransferThreadSliceLengths_K_N,
                                                    BBlockTransferThreadClusterLengths_K_N,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
-                                                   decltype(b_k_n_grid_desc),
-                                                   decltype(b_k_n_block_desc),
+                                                   decltype(b_k_n0_n1_grid_desc),
+                                                   decltype(b_k_n0_n1_block_desc),
                                                    BBlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
+                                                   Sequence<0, 1, 2>,
                                                    BBlockTransferSrcVectorDim,
-                                                   1,
+                                                   2,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BBlockTransferDstScalarPerVector_N,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                b_k_n_grid_desc,
-                make_multi_index(0, n_block_data_idx_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+                                                   true>(
+                b_k_n0_n1_grid_desc,
+                make_multi_index(0, n0_idx, 0),
+                b_k_n0_n1_block_desc,
+                make_multi_index(0, 0, 0));

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
@@ -398,14 +392,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             sequence_to_tuple_of_number(c_m10_n10_m11_n11_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size = math::integer_least_multiple(
-            a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
+            a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_space_size = math::integer_least_multiple(
-            b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
+            b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
@@ -419,37 +413,41 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             c_thread_buf, FloatAcc{0});

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGridIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGridIteratorHacks{};
+        constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
+        constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
+        constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
             AGridMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
+        constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
             BGridMoveSliceWindowIteratorHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

-        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_aligned_space_size,
+            a_k_m0_m1_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_aligned_space_size,
+            b_k_n0_n1_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -461,20 +459,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k_m0_m1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k_n0_n1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(
-                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-                b_blockwise_copy.RunRead(
-                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
@@ -483,32 +483,34 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                    c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k_m0_m1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k_n0_n1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(
-                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
-                b_blockwise_copy.RunRead(
-                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
                                    a_block_odd_buf,
                                    b_block_odd_buf,
                                    c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
@@ -517,26 +519,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
                                                 a_block_slice_copy_step,
-                                                a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
+                                                a_k_m0_m1_global_move_slice_window_iterator_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
                                                 b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+                                                b_k_n0_n1_global_move_slice_window_iterator_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(c_m10_n10_m11_n11_thread_desc,
                                a_block_even_buf,
                                b_block_even_buf,
                                c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

             __syncthreads();
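Taken together, the main loop and this tail form a classic LDS double-buffer (ping-pong) pipeline: while the GEMM consumes one buffer, the blockwise copies stage the next K-slab into the other, and `HasDoubleTailKBlockLoop` drains the final two slabs. A runnable host-side schematic of the schedule (hedged: print statements stand in for the `RunRead`/`RunWrite`/`Run` calls above):

```cpp
#include <cstdio>

int main()
{
    const int K = 64, KPerBlock = 8;
    const int even = 0, odd = 1; // stand-ins for the two LDS buffers

    int k = 0;
    // preload slab 0 into the even buffer (the "{ RunRead; RunWrite; }" block)
    std::printf("load slab %d -> buf%d\n", k / KPerBlock, even);

    // main loop: two slabs per trip, alternating buffers
    for(; k < K - 2 * KPerBlock; k += 2 * KPerBlock)
    {
        std::printf("gemm buf%d while loading slab %d -> buf%d\n", even, k / KPerBlock + 1, odd);
        std::printf("gemm buf%d while loading slab %d -> buf%d\n", odd, k / KPerBlock + 2, even);
    }

    // double tail: one slab left to load, two left to compute
    std::printf("load slab %d -> buf%d, gemm buf%d, then gemm buf%d\n",
                k / KPerBlock + 1, odd, even, odd);
}
```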
@@ -593,10 +597,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                               CGlobalMemoryDataOperation,
                                               1,
                                               true>{
             c_m0_m10_m11_n0_n10_n11_grid_desc,
-            make_multi_index(m_block_work_idx,
+            make_multi_index(m0_idx,
                              c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                              c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
-                             n_block_work_idx,
+                             n0_idx,
                              c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                              c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
             .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
composable_kernel/include/utility/math.hpp
@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
 template <class X, class Y>
 __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
 {
-    return (x + y - 1) / y;
+    return (x + y - Number<1>{}) / y;
 }

 template <class X, class Y>
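The one-character-looking change above is really a type fix: `Number<>` arithmetic stays a compile-time constant only while every operand is a `Number<>`, so a plain `1` would collapse the quotient to a runtime integer (a hedged reading, consistent with how `integral_constant`-style operator overloads behave). An illustration with `std::integral_constant` standing in for `ck::Number`:

```cpp
#include <type_traits>

template <int N>
using Num = std::integral_constant<int, N>;

// minimal compile-time arithmetic, mimicking Number<> operator overloads
template <int X, int Y> constexpr Num<X + Y> operator+(Num<X>, Num<Y>) { return {}; }
template <int X, int Y> constexpr Num<X - Y> operator-(Num<X>, Num<Y>) { return {}; }
template <int X, int Y> constexpr Num<X / Y> operator/(Num<X>, Num<Y>) { return {}; }

template <class X, class Y>
constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - Num<1>{}) / y; // all-Num expression keeps the result a Num<>
}

// the quotient is itself a type-level constant, usable as a tensor length
static_assert(std::is_same_v<decltype(integer_divide_ceil(Num<10>{}, Num<4>{})), Num<3>>);
```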
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
@@ -78,302 +78,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif

-#if 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize 64, 16x128x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize 64, 16x256x2
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 2;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 1;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize 64, 16x256x4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 1;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize = 64, 16x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 32x256x8
-    constexpr index_t BlockSize = 128;
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x2
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x4
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
+#if 1
 #if 1
     // cdata = 64, BlockSize = 256, 128x128x8
     // b thread copy 4x1
     constexpr index_t BlockSize = 256;
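Each of these (now deleted) tuning blocks obeys the same tiling invariants, which is what made them interchangeable behind one `#elif`. A hedged sanity-check sketch using the first block's numbers — the invariants are inferred from the parameter names, not stated anywhere in the file:

```cpp
using index_t = int;

// values from the "cdata = 16, BlockSize = 64, 16x64x4" block above
constexpr index_t BlockSize = 64, GemmKPerBlock = 4, GemmNPerBlock = 64;

constexpr index_t BSliceK = 4, BSliceN = 1;      // GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
constexpr index_t BClusterK = 1, BClusterN = 64; // GemmBBlockTransferThreadClusterLengths_GemmK_GemmN

static_assert(BSliceK * BClusterK == GemmKPerBlock, "cluster x slice must tile K");
static_assert(BSliceN * BClusterN == GemmNPerBlock, "cluster x slice must tile N");
static_assert(BClusterK * BClusterN == BlockSize, "one cluster cell per thread");
```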
@@ -391,82 +96,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1 = Sequence<4, 1, 1>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1 = Sequence<2, 1, 128>;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM1 = 1;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1 = Sequence<4, 1, 1>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1 = Sequence<2, 1, 128>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN1 = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN1 = 1;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN11 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b thread copy 2x2
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x16
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
 #endif

     const auto descs =
@@ -483,15 +125,23 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     const auto out_gemmm_gemmn_grid_desc = descs[I2];

     // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto wei_gemmk_gemmm_grid_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
+    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}));

-    constexpr auto in_gemmk_gemmn_grid_iterator_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+    constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

     constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
@@ -507,10 +157,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
                               Sequence<0, 0, 2, 0, 0>{},
                               Sequence<0, 0, 2, 0, 0>{}));

-    constexpr auto wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
+    constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0>{};

-    constexpr auto in_gemmk_gemmn_grid_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+    constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

     for(index_t i = 0; i < 5; ++i)
     {
@@ -533,31 +184,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
             GemmNLevel1Cluster,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
-            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-            Sequence<1, 0>,
-            Sequence<1, 0>,
-            0,
+            GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1,
+            GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1,
+            Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
+            Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
+            0,                 // ABlockTransferSrcVectorDim
             GemmABlockTransferSrcScalarPerVector_GemmK,
-            GemmABlockTransferDstScalarPerVector_GemmM,
+            GemmABlockTransferDstScalarPerVector_GemmM1,
             false, // don't move back src coordinate after threadwise copy
-            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
-            1,
-            GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmN,
-            false, // don't move back src coordinate after threadwise copy, which will be fused with
-                   // MoveSrcSliceWindow() to save addr computation
+            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1,
+            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1,
+            Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
+            Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
+            2,                 // BBlockTransferSrcVectorDim
+            GemmBBlockTransferSrcScalarPerVector_GemmN1,
+            GemmBBlockTransferDstScalarPerVector_GemmN1,
+            false, // don't move back src coordinate after threadwise copy, which will be fused with
+                   // MoveSrcSliceWindow() to save addr computation
             Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
             5,                          // CThreadTransferSrcDstVectorDim
-            GemmCThreadTransferDstScalarPerVector_GemmN1,
-            decltype(wei_gemmk_gemmm_grid_iterator_hacks),
-            decltype(in_gemmk_gemmn_grid_iterator_hacks),
+            GemmCThreadTransferDstScalarPerVector_GemmN11,
+            decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
+            decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
             decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
-            decltype(wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks),
-            decltype(in_gemmk_gemmn_grid_move_slice_window_iterator_hacks)>(
+            decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
+            decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
             static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                 wei_k_c_y_x_device_buf.GetDeviceBuffer()),
             static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
@@ -566,11 +216,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
wei_gemmk_gemmm_grid_desc
,
in_gemmk_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk_gemmm_grid_iterator_hacks
,
in_gemmk_gemmn_grid_iterator_hacks
,
wei_gemmk_gemmm
0_gemmn1
_grid_iterator_hacks
,
in_gemmk_gemmn
0_gemmn1
_grid_iterator_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks
,
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks
,
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks
,
wei_gemmk_gemmm
0_gemmm1
_grid_move_slice_window_iterator_hacks
,
in_gemmk_gemmn
0_gemmn1
_grid_move_slice_window_iterator_hacks
,
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
...
...
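The `perf` figure divides a FLOP count by the measured time. `calculate_convolution_flops` itself is not shown in this diff; the conventional count it presumably implements (one multiply plus one add per output-element/reduction-element pair) looks like this:

```cpp
#include <cstdint>

// hypothetical reimplementation for illustration; the driver's actual
// signature is not visible in this diff
std::uint64_t convolution_flops(std::uint64_t N, std::uint64_t K, std::uint64_t C,
                                std::uint64_t Y, std::uint64_t X,
                                std::uint64_t Ho, std::uint64_t Wo)
{
    return 2 * N * K * C * Y * X * Ho * Wo; // 2 = one mul + one add per MAC
}
// TFLOPS would then be flops / (avg_time_ms * 1e-3) / 1e12
```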