gaoqiong / composable_kernel · Commits

Commit 9a383af9 authored May 29, 2021 by Chao Liu
overhauling fwd-v4r4
parent e9956403
Showing 7 changed files with 2196 additions and 2 deletions.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp  +387 -0
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp  +165 -0
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp  +514 -0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp  +552 -0
composable_kernel/include/utility/config.amd.hpp.in  +1 -1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp  +559 -0
driver/src/conv_driver.cpp  +18 -1
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
0 → 100644
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2
#define CK_DRIVER_DYNAMIC_GEMM_V1R2

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r2.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
                                               const FloatAB* p_b_global,
                                               FloatC* p_c_global,
                                               const AGlobalDesc& a_k_m_global_desc,
                                               const BGlobalDesc& b_k_n_global_desc,
                                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                               const CBlockClusterDesc& c_block_cluster_desc,
                                               AGlobalIteratorHacks,
                                               BGlobalIteratorHacks,
                                               CGlobalIteratorHacks,
                                               AGlobalMoveSliceWindowIteratorHacks,
                                               BGlobalMoveSliceWindowIteratorHacks,
                                               index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size no divisible");
    }

    constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
    constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size no divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                CGlobalMemoryDataOperation,
                                                AGlobalDesc,
                                                BGlobalDesc,
                                                CGlobalDesc,
                                                CBlockClusterDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                M1PerThread,
                                                N1PerThread,
                                                KPerThread,
                                                M1N1ThreadClusterM10,
                                                M1N1ThreadClusterN10,
                                                M1N1ThreadClusterM11,
                                                M1N1ThreadClusterN11,
                                                ABlockTransferThreadSliceLengths_K_M,
                                                ABlockTransferThreadClusterLengths_K_M,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_M,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K_N,
                                                BBlockTransferThreadClusterLengths_K_N,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_N,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGlobalIteratorHacks,
                                                BGlobalIteratorHacks,
                                                CGlobalIteratorHacks,
                                                AGlobalMoveSliceWindowIteratorHacks,
                                                BGlobalMoveSliceWindowIteratorHacks>;

    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
                                          p_a_global, p_b_global, p_c_global,
                                          a_k_m_global_desc, b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc, c_block_cluster_desc);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
                                          p_a_global, p_b_global, p_c_global,
                                          a_k_m_global_desc, b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc, c_block_cluster_desc);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
                                          p_a_global, p_b_global, p_c_global,
                                          a_k_m_global_desc, b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc, c_block_cluster_desc);
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
                                          p_a_global, p_b_global, p_c_global,
                                          a_k_m_global_desc, b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc, c_block_cluster_desc);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true, true>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
            p_a_global, p_b_global, p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true, false>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
            p_a_global, p_b_global, p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false, true>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
            p_a_global, p_b_global, p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, FloatAB, FloatAB, FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false, false>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
            p_a_global, p_b_global, p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
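Editor's note on the file above: before picking one of the four kernel instantiations, the driver derives two booleans. has_main_k_block_loop says whether the double-buffered main K loop runs at all, and has_double_tail_k_block_loop says whether two K-blocks (rather than one) are left for the tail. The snippet below is a minimal standalone sketch, not part of the commit, that simply evaluates the same two expressions for a few sample K / KPerBlock combinations so the four <HasMainKBlockLoop, HasDoubleTailKBlockLoop> cases are easy to see.

// Standalone sketch (not from the commit): reproduces the driver's loop-classification logic.
#include <cstdio>

int main()
{
    const int KPerBlock = 8;

    for(int K : {8, 16, 24, 64})
    {
        // same expressions as in launch_kernel_dynamic_gemm_v1r2
        const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
        const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

        std::printf("K=%3d KPerBlock=%d -> kernel<..., %s, %s>\n",
                    K,
                    KPerBlock,
                    has_main_k_block_loop ? "true" : "false",
                    has_double_tail_k_block_loop ? "true" : "false");
    }
    return 0;
}

With KPerBlock = 8 this prints the <false,false>, <false,true>, <true,false> and <true,true> cases for K = 8, 16, 24 and 64 respectively, which is exactly the branch the driver would take.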
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
0 → 100644
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmM1,
          index_t GemmN1,
          typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hip_wip_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        in_n_c_y_ho_x_wo_global_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
           GemmK % GemmKPerBlock == 0);

    const auto GemmM0 = GemmM / Number<GemmM1>{};
    const auto GemmN0 = GemmN / Number<GemmN1>{};

    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
        out_gemmm_gemmn_global_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

    // out_gemm_block_cluster_desc
    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));

    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};

    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));

    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};

    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
    // tensor hack for NKHW format
    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));

    return make_tuple(wei_gemmk_gemmm_global_desc,
                      in_gemmk_gemmn_global_desc,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                      out_gemm_block_cluster_desc,
                      wei_gemmk_gemmm_global_iterator_hacks,
                      in_gemmk_gemmn_global_iterator_hacks,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
}

} // namespace ck
#endif
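Editor's note on the file above: the header comment states the implicit-GEMM mapping GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X. The sketch below is a minimal standalone illustration, not part of the commit, of what those GEMM dimensions come out to for one sample NCHW/KCYX problem; the Ho/Wo expressions use the usual padded, dilated convolution output-size formula, which is an assumption consistent with the pad and embed transforms above, and the sample sizes are hypothetical.

// Standalone sketch (not from the commit): implicit-GEMM problem size for a sample convolution.
#include <cstdio>

int main()
{
    // sample convolution sizes, chosen only for illustration
    const int N = 128, C = 192, Hi = 71, Wi = 71; // input:  N x C x Hi x Wi
    const int K = 256, Y = 3, X = 3;              // weight: K x C x Y x X
    const int ConvStrideH = 2, ConvStrideW = 2;
    const int ConvDilationH = 1, ConvDilationW = 1;
    const int InLeftPadH = 1, InLeftPadW = 1, InRightPadH = 1, InRightPadW = 1;

    // standard output-size formula for a padded, dilated convolution (assumption)
    const int Ho = (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
    const int Wo = (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;

    const long GemmM = K;                // one GEMM row per output channel
    const long GemmN = 1L * N * Ho * Wo; // one GEMM column per output pixel
    const long GemmK = 1L * C * Y * X;   // reduction over input channel and filter window

    std::printf("Ho=%d Wo=%d  GemmM=%ld GemmN=%ld GemmK=%ld\n", Ho, Wo, GemmM, GemmN, GemmK);
    return 0;
}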
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
0 → 100644
#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP

#include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
//   1. A:
//     1. ABlockDesc is known at compile-time
//     2. ABlockBuffer is DynamicBuffer
//   2. B:
//     1. ABlockDesc is known at compile-time
//     2. BBlockBuffer is DynamicBuffer
//   3. C:
//     1. CThreadDesc is known at compile-time
//     2. CThreadBuffer is StaticBuffer
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc,
          typename BBlockDesc,
          typename CThreadDesc,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
                                      BBlockDesc::IsKnownAtCompileTime() &&
                                      CThreadDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    public:
    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
                          CThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");
    }

    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
    {
        constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
        constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
        constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
        constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

        // 4-d data space into 4-d thread space
        // upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 *
        // M1N1ThreadClusterN11} lower: {M0, M1, N0, N1}
        constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
            make_tuple(make_vectorize_transform(M0, 1),
                       make_vectorize_transform(M1PerThread, M1 / M1PerThread),
                       make_vectorize_transform(N0, 1),
                       make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // thread position 4-d thread space
        // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
        // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
        // M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

        // 4-d thread space to 1-d thread space
        // upper: {BlockSize}
        // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
        // M1N1ThreadClusterN11}
        constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
                                                       M1N1ThreadClusterN10,
                                                       M1N1ThreadClusterM11,
                                                       M1N1ThreadClusterN11))),
            make_tuple(Sequence<0, 2, 1, 3>{}),
            make_tuple(Sequence<0>{}));

        constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);

        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<M0_, M1PerThread>,
                                                Sequence<N0_, N1PerThread>>{};

        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        static_for<0, K, KPerThread>{}([&](auto k) {
            a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I0, I0), a_block_buf,
                               a_thread_desc_, make_tuple(I0, I0, I0), a_thread_buf);

            b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I0, I0), b_block_buf,
                               b_thread_desc_, make_tuple(I0, I0, I0), b_thread_buf);

            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I0, I0));
        });
    }

    private:
    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

    // A[K, M0, M1]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

    // B[K, N0, N1]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                ABlockDesc,
                                                decltype(a_thread_desc_),
                                                Sequence<KPerThread, M0_, M1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M1,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                BBlockDesc,
                                                decltype(b_thread_desc_),
                                                Sequence<KPerThread, N0_, N1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N1,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
//   1. A:
//     1. ABlockDesc is known at compile-time
//     2. ABlockBuffer is DynamicBuffer
//   2. B:
//     1. ABlockDesc is known at compile-time
//     2. BBlockBuffer is DynamicBuffer
//   3. C:
//     1. CThreadDesc is known at compile-time
//     2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc,
          typename BBlockDesc,
          typename CThreadDesc,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
                                      BBlockDesc::IsKnownAtCompileTime() &&
                                      CThreadDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    public:
    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
                          CThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
                          CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
                      "wrong");
    }

    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
    {
        constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
        constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
        constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
        constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

        // 4-d data space into 4-d thread space
        constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
            make_tuple(make_vectorize_transform(M0, 1),
                       make_vectorize_transform(M1PerThread, M1 / M1PerThread),
                       make_vectorize_transform(N0, 1),
                       make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // thread position 4-d thread space
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

        // 4-d thread space to 1-d thread space
        constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
                                                       M1N1ThreadClusterN10,
                                                       M1N1ThreadClusterM11,
                                                       M1N1ThreadClusterN11))),
            make_tuple(Sequence<0, 2, 1, 3>{}),
            make_tuple(Sequence<0>{}));

        constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);

        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<1, M1PerThread>,
                                                Sequence<1, N1PerThread>>{};

        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        // read A_sub_0
        a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I0, I0), a_block_buf,
                           a_thread_desc_, make_tuple(I0, I0, I0), a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I0, I0), b_block_buf,
                           b_thread_desc_, make_tuple(I0, I0, I0), b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I1, I0), b_block_buf,
                           b_thread_desc_, make_tuple(I0, I1, I0), b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I1, I0), a_block_buf,
                           a_thread_desc_, make_tuple(I0, I1, I0), a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I0, I0),
                            c_thread_buf, make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I1, I0),
                            c_thread_buf, make_tuple(I0, I0, I1, I0));

        // loop over rest of k
        static_for<KPerThread, K, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I0, I0), a_block_buf,
                               a_thread_desc_, make_tuple(I0, I0, I0), a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I0, I0), b_block_buf,
                               b_thread_desc_, make_tuple(I0, I0, I0), b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                                b_thread_buf, make_tuple(I0, I1, I0),
                                c_thread_buf, make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I1, I0), b_block_buf,
                               b_thread_desc_, make_tuple(I0, I1, I0), b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I1, I0), a_block_buf,
                               a_thread_desc_, make_tuple(I0, I1, I0), a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I1, I0),
                                c_thread_buf, make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                            b_thread_buf, make_tuple(I0, I0, I0),
                            c_thread_buf, make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                            b_thread_buf, make_tuple(I0, I1, I0),
                            c_thread_buf, make_tuple(I1, I0, I1, I0));
    }

    private:
    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

    // A[K, M0, M1]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

    // B[K, N0, N1]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                ABlockDesc,
                                                decltype(a_thread_desc_),
                                                Sequence<KPerThread, 1, M1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M1,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                BBlockDesc,
                                                decltype(b_thread_desc_),
                                                Sequence<KPerThread, 1, N1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N1,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
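Editor's note on the file above: the v2_pipeline_2x2 variant interleaves register reads with FMAs so that each C sub-tile's operands are already resident when its GEMM is issued (the "2x2 pipelined read and fma (ABBA optimization)" mentioned in the header comment). The sketch below is a minimal standalone illustration, not part of the commit; the tensor machinery is replaced by stub lambdas so only the read/compute ordering of Run remains visible.

// Standalone sketch (not from the commit): the 2x2 ABBA read/FMA schedule of the pipelined Run.
#include <cstdio>

int main()
{
    const int K = 16, KPerThread = 4;

    auto read = [](const char* what, int k) { std::printf("read  %s (k=%d)\n", what, k); };
    auto fma  = [](const char* what)        { std::printf("gemm  %s\n", what); };

    // prologue: read A0, B0, B1, A1, then issue the two C rows that only need A_sub_0
    read("A_sub_0", 0); read("B_sub_0", 0); read("B_sub_1", 0); read("A_sub_1", 0);
    fma("C_sub_00 += A_sub_0^T * B_sub_0");
    fma("C_sub_01 += A_sub_0^T * B_sub_1");

    // steady state: every new read is overlapped with an FMA on data already in registers
    for(int k = KPerThread; k < K; k += KPerThread)
    {
        read("A_sub_0", k); fma("C_sub_10 += A_sub_1^T * B_sub_0");
        read("B_sub_0", k); fma("C_sub_11 += A_sub_1^T * B_sub_1");
        read("B_sub_1", k); read("A_sub_1", k);
        fma("C_sub_00 += A_sub_0^T * B_sub_0");
        fma("C_sub_01 += A_sub_0^T * B_sub_1");
    }

    // epilogue: the last two C rows
    fma("C_sub_10 += A_sub_1^T * B_sub_0");
    fma("C_sub_11 += A_sub_1^T * B_sub_1");
    return 0;
}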
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
0 → 100644
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"

namespace ck {

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
                                 const FloatB* __restrict__ p_b_global,
                                 FloatC* __restrict__ p_c_global,
                                 const AGlobalDesc a_k_m_global_desc,
                                 const BGlobalDesc b_k_n_global_desc,
                                 const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                                 const CBlockClusterDesc c_block_cluster_desc)
{
    GridwiseGemm::Run(p_a_global,
                      p_b_global,
                      p_c_global,
                      a_k_m_global_desc,
                      b_k_n_global_desc,
                      c_m0_m1_n0_n1_global_desc,
                      c_block_cluster_desc,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
                                 const FloatB* __restrict__ p_b_global,
                                 FloatC* __restrict__ p_c_global,
                                 const void __CONSTANT__* p_a_k_m_global_desc,
                                 const void __CONSTANT__* p_b_k_n_global_desc,
                                 const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
                                 const void __CONSTANT__* p_c_block_cluster_desc)
{
    // first cast void __CONSTANT__ void* to void*
    // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
    const auto a_k_m_global_desc =
        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
    const auto b_k_n_global_desc =
        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
    const auto c_m0_m1_n0_n1_global_desc =
        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
    const auto c_block_cluster_desc =
        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

    GridwiseGemm::Run(p_a_global,
                      p_b_global,
                      p_c_global,
                      a_k_m_global_desc,
                      b_k_n_global_desc,
                      c_m0_m1_n0_n1_global_desc,
                      c_block_cluster_desc,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                 Number<BBlockTransferDstScalarPerVector_N>{},
                                                 Number<M1PerThread>{},
                                                 Number<N1PerThread>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
    }

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                               const FloatAB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
                               const AGlobalDesc& a_k_m_global_desc,
                               const BGlobalDesc& b_k_n_global_desc,
                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
                               FloatAB* __restrict__ p_shared_block,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());

        const auto K = a_k_m_global_desc.GetLength(I0);
        const auto M = a_k_m_global_desc.GetLength(I1);
        const auto N = b_k_n_global_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_global into SGPR
        const index_t m_block_data_idx_on_global =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_global =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                 Number<BBlockTransferDstScalarPerVector_N>{},
                                                 Number<M1PerThread>{},
                                                 Number<N1PerThread>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                   InMemoryDataOperation::Set,
                                                   Sequence<KPerBlock, MPerBlock>,
                                                   ABlockTransferThreadSliceLengths_K_M,
                                                   ABlockTransferThreadClusterLengths_K_M,
                                                   ABlockTransferThreadClusterArrangeOrder,
                                                   FloatAB,
                                                   FloatAB,
                                                   decltype(a_k_m_global_desc),
                                                   decltype(a_k_m_block_desc),
                                                   ABlockTransferSrcAccessOrder,
                                                   Sequence<0, 1>,
                                                   ABlockTransferSrcVectorDim,
                                                   1,
                                                   ABlockTransferSrcScalarPerVector,
                                                   ABlockTransferDstScalarPerVector_M,
                                                   1,
                                                   1,
                                                   AThreadTransferSrcResetCoordinateAfterRun,
                                                   true>(
                a_k_m_global_desc,
                make_multi_index(0, m_block_data_idx_on_global),
                a_k_m_block_desc,
                make_multi_index(0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                   InMemoryDataOperation::Set,
                                                   Sequence<KPerBlock, NPerBlock>,
                                                   BBlockTransferThreadSliceLengths_K_N,
                                                   BBlockTransferThreadClusterLengths_K_N,
                                                   BBlockTransferThreadClusterArrangeOrder,
                                                   FloatAB,
                                                   FloatAB,
                                                   decltype(b_k_n_global_desc),
                                                   decltype(b_k_n_block_desc),
                                                   BBlockTransferSrcAccessOrder,
                                                   Sequence<0, 1>,
                                                   BBlockTransferSrcVectorDim,
                                                   1,
                                                   BBlockTransferSrcScalarPerVector,
                                                   BBlockTransferDstScalarPerVector_N,
                                                   1,
                                                   1,
                                                   BThreadTransferSrcResetCoordinateAfterRun,
                                                   true>(
                b_k_n_global_desc,
                make_multi_index(0, n_block_data_idx_on_global),
                b_k_n_block_desc,
                make_multi_index(0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, MPerBlock] is in LDS
        // b_mtx[KPerBlocl, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        static_assert(
            MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
                NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
            "wrong!");

        constexpr index_t M0PerThread =
            MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
        constexpr index_t N0PerThread =
            NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);

        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
            a_k_m_block_desc,
            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
                       make_unmerge_transform(make_tuple(
                           Number<M0PerThread>{},
                           Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
            b_k_n_block_desc,
            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
                       make_unmerge_transform(make_tuple(
                           Number<N0PerThread>{},
                           Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<M0PerThread>{},
                       Number<M1PerThread>{},
                       Number<N0PerThread>{},
                       Number<N1PerThread>{}));

        const auto blockwise_gemm =
            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
                                                               FloatAB,
                                                               FloatAB,
                                                               FloatAcc,
                                                               decltype(a_k_m0_m1_block_desc),
                                                               decltype(b_k_n0_n1_block_desc),
                                                               decltype(c_m0_m1_n0_n1_thread_desc),
                                                               M1PerThread,
                                                               N1PerThread,
                                                               KPerThread,
                                                               M1N1ThreadClusterM10,
                                                               M1N1ThreadClusterN10,
                                                               M1N1ThreadClusterM11,
                                                               M1N1ThreadClusterN11,
                                                               M1PerThread,
                                                               N1PerThread>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                           decltype(c_m0_m1_n0_n1_thread_desc),
                                           Sequence<M0PerThread,
                                                    M1PerThread,
                                                    N0PerThread,
                                                    N1PerThread>>{}
            .Run(c_m0_m1_n0_n1_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};

        // hack to control index calculation when move slice window for A and B matrix for
        // threadwise copy
        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
            AGlobalMoveSliceWindowIteratorHacks{};
        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
            BGlobalMoveSliceWindowIteratorHacks{};

        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < K - 2 * KPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                a_block_slice_copy_step,
                                                a_k_m_global_move_slice_window_iterator_hack);
            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                b_block_slice_copy_step,
                                                b_k_n_global_move_slice_window_iterator_hack);

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if has 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto M1 =
                Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
            constexpr auto N1 =
                Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};

            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

            const auto c_thread_data_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());

            ThreadwiseDynamicTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_m0_m1_n0_n1_thread_desc),
                decltype(c_m0_m1_n0_n1_global_desc),
                Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_m0_m1_n0_n1_global_desc,
                      make_multi_index(
                          m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                          c_thread_data_idx_on_block[I1],
                          n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
                          c_thread_data_idx_on_block[I3])}
                .Run(c_m0_m1_n0_n1_thread_desc,
                     make_tuple(I0, I0, I0, I0),
                     c_thread_buf,
                     c_m0_m1_n0_n1_global_desc,
                     c_global_buf,
                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
        }
    }

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                               const FloatAB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
                               const AGlobalDesc& a_k_m_global_desc,
                               const BGlobalDesc& b_k_n_global_desc,
                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

        __shared__ FloatAB p_shared_block[shared_block_size];

        Run(p_a_global,
            p_b_global,
            p_c_global,
            a_k_m_global_desc,
            b_k_n_global_desc,
            c_m0_m1_n0_n1_global_desc,
            c_block_cluster_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
            integral_constant<bool, HasDoubleTailKBlockLoop>{});
    }
};

} // namespace ck
#endif
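Editor's note on the file above: the gridwise GEMM preloads the first K-block into the "even" LDS buffers, then alternates between the even and odd buffers so that the next global-memory load overlaps with the GEMM on the data already in LDS; the tail handles the one or two remaining K-blocks depending on HasDoubleTailKBlockLoop. The sketch below is a minimal standalone trace of that control flow, not part of the commit, with the copies and GEMMs reduced to prints.

// Standalone sketch (not from the commit): LDS double-buffering control flow of the gridwise Run.
#include <cstdio>

int main()
{
    const int K = 64, KPerBlock = 8;

    // same classification as the driver uses when it selects a kernel variant
    const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

    std::puts("preload: global -> LDS(even)");

    if(has_main_k_block_loop)
    {
        int k_block_data_begin = 0;
        do
        {
            std::puts("load next -> LDS(odd)  | gemm on LDS(even)"); // even iteration
            std::puts("load next -> LDS(even) | gemm on LDS(odd)");  // odd iteration
            k_block_data_begin += 2 * KPerBlock;
        } while(k_block_data_begin < K - 2 * KPerBlock);
    }

    if(has_double_tail_k_block_loop) // two K-blocks left
    {
        std::puts("load last -> LDS(odd) | gemm on LDS(even)");
        std::puts("gemm on LDS(odd)");
    }
    else // one K-block left
    {
        std::puts("gemm on LDS(even)");
    }
    return 0;
}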
composable_kernel/include/utility/config.amd.hpp.in
@@ -14,7 +14,7 @@
 #define CK_DEVICE_BACKEND_AMD 1
 // GPU ID
-#if 1
+#if 0
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
0 → 100644
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1r2.hpp"
template <class TInWei,
          ck::index_t InWeiVectorSize,
          class TAcc,
          class TOut,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
    InDesc,
    const Tensor<TInWei>& in_n_c_hi_wi,
    WeiDesc,
    const Tensor<TInWei>& wei_k_c_y_x,
    OutDesc,
    Tensor<TOut>& out_n_k_ho_wo,
    ConvStrides,
    ConvDilations,
    InLeftPads,
    InRightPads,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

#if 1
    // run-time variables
    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));

    const auto conv_strides   = to_multi_index(ConvStrides{});
    const auto conv_dilations = to_multi_index(ConvDilations{});
    const auto in_left_pads   = to_multi_index(InLeftPads{});
    const auto in_right_pads  = to_multi_index(InRightPads{});
#else
    // compile-time variables
    const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
        sequence_to_tuple_of_number(InDesc::GetLengths()));
    const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
        sequence_to_tuple_of_number(WeiDesc::GetLengths()));
    const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
        sequence_to_tuple_of_number(OutDesc::GetLengths()));

    const auto conv_strides   = sequence_to_tuple_of_number(ConvStrides{});
    const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
    const auto in_left_pads   = sequence_to_tuple_of_number(InLeftPads{});
    const auto in_right_pads  = sequence_to_tuple_of_number(InRightPads{});
#endif
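    // Note (illustrative, assumed semantics): "packed" means the strides are derived from the
    // lengths in row-major order. E.g. an input descriptor built from lengths
    // {N, C, Hi, Wi} = {128, 192, 71, 71} gets strides {C*Hi*Wi, Hi*Wi, Wi, 1}
    // = {967872, 5041, 71, 1}.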
#if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize 64, 16x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize 64, 16x256x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 4x1
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 2x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif
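    // Sanity-check sketch (assumed constraints, added here for illustration only): every tuning
    // block above is expected to satisfy these relations, i.e. the GEMM thread cluster fills the
    // work-group and the per-thread tiles evenly cover the block tile.
    static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                                   GemmNLevel1Cluster,
                  "BlockSize must match the GEMM thread cluster size");
    static_assert(GemmMPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
                  "GemmMPerBlock must be a multiple of the per-thread M tile times the M cluster");
    static_assert(GemmNPerBlock % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                  "GemmNPerBlock must be a multiple of the per-thread N tile times the N cluster");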
    constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
    constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;

    const auto descs =
        transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad<GemmMPerBlock,
                                                                          GemmNPerBlock,
                                                                          GemmM1,
                                                                          GemmN1>(
            wei_k_c_y_x_desc,
            in_n_c_hi_wi_desc,
            out_n_k_ho_wo_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads);

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = launch_kernel_dynamic_gemm_v1r2<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
            decltype(descs[I0]),
            decltype(descs[I1]),
            decltype(descs[I2]),
            decltype(descs[I3]),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
            Sequence<1, 0>,
            Sequence<1, 0>,
            0,
            GemmABlockTransferSrcScalarPerVector_GemmK,
            GemmABlockTransferDstScalarPerVector_GemmM,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>,
            Sequence<0, 1>,
            1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmN,
            false, // don't move back src coordinate after threadwise copy, which will be fused
                   // with MoveSrcSliceWindow() to save addr computation
            Sequence<2, 3, 0, 1>,
            3,
            GemmCThreadTransferDstScalarPerVector_GemmN1,
            decltype(descs[I4]),
            decltype(descs[I5]),
            decltype(descs[I6]),
            decltype(descs[I7]),
            decltype(descs[I8])>(
            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                wei_k_c_y_x_device_buf.GetDeviceBuffer()),
            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
            descs[I0],
            descs[I1],
            descs[I2],
            descs[I3],
            descs[I4],
            descs[I5],
            descs[I6],
            descs[I7],
            descs[I8],
            nrepeat);
        float perf = (float)calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
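The reported figure divides the convolution FLOP count by 10^9 and by the average kernel time in milliseconds, which is what prints as TFlop/s. A small standalone sketch of that conversion, assuming calculate_convolution_flops uses the standard 2 * N * K * C * Y * X * Ho * Wo multiply-add count (the helper names below and that count are assumptions, not taken from this commit):

#include <cstdint>

// FLOPs of a direct NCHW forward convolution: two ops (multiply + add) per MAC.
constexpr std::int64_t conv_flops(std::int64_t N, std::int64_t K, std::int64_t C,
                                  std::int64_t Y, std::int64_t X,
                                  std::int64_t Ho, std::int64_t Wo)
{
    return 2 * N * K * C * Y * X * Ho * Wo;
}

// Mirrors the driver's formula: flops / 1e9 / time_in_ms equals TFlop/s.
constexpr double tflops(std::int64_t flops, double ave_time_ms)
{
    return static_cast<double>(flops) / 1.0e9 / ave_time_ms;
}

For instance, N = 128, K = C = 256, a 3x3 filter and a 14x14 output give about 2.96e10 FLOPs; at an average time of 1 ms that prints as roughly 29.6 TFlop/s.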
driver/src/conv_driver.cpp
...
@@ -15,6 +15,7 @@
 #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
 #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
 #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
 #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
...
@@ -724,7 +725,7 @@ int main(int argc, char* argv[])
                                                                          LeftPads{},
                                                                          RightPads{},
                                                                          nrepeat);
-#elif 1
+#elif 0
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
...
@@ -740,6 +741,22 @@ int main(int argc, char* argv[])
                                                                          LeftPads{},
                                                                          RightPads{},
                                                                          nrepeat);
+#elif 1
+    device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw<in_data_t,
+                                                                           in_vector_size,
+                                                                           acc_data_t,
+                                                                           out_data_t>(
+        in_nchw_desc,
+        in_nchw,
+        wei_kcyx_desc,
+        wei_kcyx,
+        out_nkhw_desc,
+        out_nkhw_device,
+        ConvStrides{},
+        ConvDilations{},
+        LeftPads{},
+        RightPads{},
+        nrepeat);
 #elif 1
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
...