gaoqiong / composable_kernel · Commits

Commit b6ece3c6, authored Jun 16, 2022 by wangshaojie6
use AK1/BK1
parent 78690467

Showing 3 changed files with 143 additions and 387 deletions (+143 −387):

  example/01_gemm/gemm_xdl_fp16_splitk.cpp                                        +6  −13
  include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp    +74  −73
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp           +63 −301
example/01_gemm/gemm_xdl_fp16_splitk.cpp

@@ -46,19 +46,12 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
-//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
-//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
-//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    < F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 32, 2>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4>;
-//< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
-//< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
-//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4>;
-//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
+//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    < Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 4, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 4, 1, 1, S<1, 16, 1, 16>, 4>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::...
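A side note, not part of the diff: in the old instance the per-block K extent was expressed as K0PerBlock × K1 (4 × 8), while the new instance passes it directly as KPerBlock = 32 together with separate AK1/BK1 vector widths. A minimal consistency check, assuming only the tuning values visible in the example above:

    // Hypothetical sanity check of the example's tuning values (not in the commit).
    constexpr int K0PerBlock_old = 4, K1_old = 8;        // old interface: K0PerBlock, K1
    constexpr int KPerBlock_new = 32, AK1 = 8, BK1 = 2;  // new interface: KPerBlock, AK1, BK1

    static_assert(K0PerBlock_old * K1_old == KPerBlock_new,
                  "the new KPerBlock equals the old K0PerBlock * K1");
    static_assert(KPerBlock_new % AK1 == 0 && KPerBlock_new % BK1 == 0,
                  "KPerBlock must be divisible by both AK1 and BK1");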
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
@@ -21,44 +21,48 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename AccDataType,
-          typename ALayout,
+template <typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename GemmAccDataType,
+          typename CShuffleDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
-          ck::index_t BlockSize,
-          ck::index_t MPerBlock,
-          ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
-          ck::index_t K1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
-          ck::index_t MXdlPerWave,
-          ck::index_t NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_K0_M_K1,
+          index_t NumGemmKPrefetchStage,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t AK1,
+          index_t BK1,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
-          ck::index_t ABlockTransferSrcVectorDim,
-          ck::index_t ABlockTransferSrcScalarPerVector,
-          ck::index_t ABlockTransferDstScalarPerVector_K1,
-          bool ABlockLdsAddExtraM,
-          typename BBlockTransferThreadClusterLengths_K0_N_K1,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          index_t ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
-          ck::index_t BBlockTransferSrcVectorDim,
-          ck::index_t BBlockTransferSrcScalarPerVector,
-          ck::index_t BBlockTransferDstScalarPerVector_K1,
-          bool BBlockLdsAddExtraN,
-          index_t CShuffleMRepeatPerShuffle,
-          index_t CShuffleNRepeatPerShuffle,
-          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          index_t BBlockLdsExtraN,
+          index_t CShuffleMXdlPerWavePerShuffle,
+          index_t CShuffleNXdlPerWavePerShuffle,
+          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmXdlSplitKCShuffle
     : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
@@ -67,14 +71,12 @@ struct DeviceGemmXdlSplitKCShuffle
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    static constexpr auto K1Number = Number<K1>{};
-
     static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
     {
-        assert(KPad % (K1 * KBatch) == 0);
+        assert(KPad % (AK1 * KBatch) == 0);

-        const index_t K0 = KPad / (K1 * KBatch);
+        const index_t AK0 = KPad / (AK1 * KBatch);

         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -98,7 +100,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

                 return transform_tensor_descriptor(
                     a_grid_desc_m_kpad,
-                    make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
+                    make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
                                make_right_pad_transform(M, PadM)),
                     make_tuple(Sequence<1>{}, Sequence<0>{}),
                     make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -107,7 +109,7 @@ struct DeviceGemmXdlSplitKCShuffle
             {
                 return transform_tensor_descriptor(
                     a_grid_desc_m_kpad,
-                    make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
+                    make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
                                make_pass_through_transform(M)),
                     make_tuple(Sequence<1>{}, Sequence<0>{}),
                     make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -117,9 +119,9 @@ struct DeviceGemmXdlSplitKCShuffle
     static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
     {
-        assert(KPad % (K1 * KBatch) == 0);
+        assert(KPad % (BK1 * KBatch) == 0);

-        const index_t K0 = KPad / (K1 * KBatch);
+        const index_t BK0 = KPad / (BK1 * KBatch);

         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -143,7 +145,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

                 return transform_tensor_descriptor(
                     b_grid_desc_kpad_n,
-                    make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
+                    make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
                                make_right_pad_transform(N, PadN)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -152,7 +154,7 @@ struct DeviceGemmXdlSplitKCShuffle
             {
                 return transform_tensor_descriptor(
                     b_grid_desc_kpad_n,
-                    make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
+                    make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
                                make_pass_through_transform(N)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -196,8 +198,7 @@ struct DeviceGemmXdlSplitKCShuffle
     static auto GetKPad(index_t K, index_t KBatch)
     {
-        const index_t K0   = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
-        const index_t KPad = KBatch * K0 * K1;
+        const index_t KPad =
+            math::integer_divide_ceil(K, KPerBlock * KBatch) * (KPerBlock * KBatch);
         return KPad;
     }
@@ -209,7 +210,7 @@ struct DeviceGemmXdlSplitKCShuffle
     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
         ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
+        GemmAccDataType,
         CDataType,
         InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
@@ -218,42 +219,42 @@ struct DeviceGemmXdlSplitKCShuffle
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        NumGemmKPrefetchStage,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
+        KPerBlock,
+        AK1,
+        BK1,
         MPerXDL,
         NPerXDL,
-        K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
+        ABlockTransferDstScalarPerVector_AK1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
+        BBlockTransferDstScalarPerVector_BK1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
-        CShuffleMRepeatPerShuffle,
-        CShuffleNRepeatPerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXDL,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        false,
-        3>;
+        BBlockLdsExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CShuffleBlockTransferScalarPerVector_NPerBlock,
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
     // GridwiseGemm
     using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
         ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
+        GemmAccDataType,
         CDataType,
         InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_K0_M_K1,
@@ -262,36 +263,36 @@ struct DeviceGemmXdlSplitKCShuffle
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        NumGemmKPrefetchStage,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
+        KPerBlock,
+        AK1,
+        BK1,
         MPerXDL,
         NPerXDL,
-        K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
+        ABlockTransferDstScalarPerVector_AK1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
+        BBlockTransferDstScalarPerVector_BK1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
-        CShuffleMRepeatPerShuffle,
-        CShuffleNRepeatPerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXDL,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        false,
-        3>;
+        BBlockLdsExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CShuffleBlockTransferScalarPerVector_NPerBlock,
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
@@ -412,9 +413,9 @@ struct DeviceGemmXdlSplitKCShuffle
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
+            const auto K = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) *
+                           arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3);

-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K);

             float ave_time = 0;
@@ -634,7 +635,7 @@ struct DeviceGemmXdlSplitKCShuffle
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock
+            << KPerBlock
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
@@ -14,127 +14,6 @@
 namespace ck {

-// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
-// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
-// will be very bad
-template <typename LowLengths>
-struct Merge_v4_no_carry
-{
-    static constexpr index_t NDimLow = LowLengths::Size();
-
-    using LowerIndex = MultiIndex<NDimLow>;
-    using UpperIndex = MultiIndex<1>;
-
-    using LowLengthsScan =
-        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
-
-    using UpLengths =
-        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
-
-    LowLengths low_lengths_;
-    LowLengthsScan low_lengths_scan_;
-    UpLengths up_lengths_;
-
-    __host__ __device__ constexpr Merge_v4_no_carry() = default;
-
-    __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths)
-        : low_lengths_{low_lengths},
-          low_lengths_scan_{
-              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
-          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
-    {
-        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
-    }
-
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
-    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
-
-    template <typename LowIdx, typename UpIdx>
-    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
-    {
-        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        index_t tmp = idx_up[Number<0>{}];
-
-        // division and mod
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_low(i) = tmp / this->low_lengths_scan_[i];
-            tmp %= this->low_lengths_scan_[i];
-        });
-
-        idx_low(Number<NDimLow - 1>{}) = tmp;
-    }
-
-    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
-    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff& idx_up_diff,
-                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
-                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        constexpr auto I0   = Number<0>{};
-        constexpr auto INm1 = Number<NDimLow - 1>{};
-
-        index_t tmp = idx_up_new[I0];
-
-        idx_low(INm1)      = tmp;
-        idx_diff_low(INm1) = idx_up_diff[I0];
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-
-    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
-    {
-        return true;
-    }
-
-    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
-    {
-        return is_known_at_compile_time<LowLengths>::value &&
-               is_known_at_compile_time<LowLengthsScan>::value &&
-               is_known_at_compile_time<UpLengths>::value;
-    }
-
-    template <typename UpIdx>
-    __host__ __device__ static constexpr bool
-    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
-    {
-        return true;
-    }
-
-    __host__ __device__ void Print() const
-    {
-        printf("{");
-        printf("Merge_v3_direct_division_mod_wrw, ");
-        printf("low_lengths_ ");
-        print_multi_index(low_lengths_);
-        printf("low_lengths_scan_ ");
-        print_multi_index(low_lengths_scan_);
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
-        printf("}");
-    }
-};
-
-template <typename LowLengths>
-__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths)
-{
-    return Merge_v4_no_carry<LowLengths>{low_lengths};
-}
-
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -204,12 +83,14 @@ template <index_t BlockSize,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
+          index_t NumGemmKPrefetchStage,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t K0PerBlock,
+          index_t KPerBlock,
+          index_t AK1Value,
+          index_t BK1Value,
           index_t MPerXDL,
           index_t NPerXDL,
-          index_t K1Value,
           index_t MRepeat,
           index_t NRepeat,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -219,7 +100,7 @@ template <index_t BlockSize,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          bool ABlockLdsExtraM,
+          index_t ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
@@ -227,13 +108,11 @@ template <index_t BlockSize,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
-          bool BBlockLdsExtraN,
+          index_t BBlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
-          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          bool BBlockLdsExtraN1         = false,
-          index_t NumGemmKPrefetchStage = 4>
+          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 {
     static constexpr auto I0 = Number<0>{};
@@ -246,11 +125,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     static constexpr auto I7 = Number<7>{};

-    // K1 should be Number<...>
-    static constexpr auto K1 = Number<K1Value>{};
-    // N0 N1
-    static constexpr auto N1PerBlock = Number<128 / (sizeof(FloatAB) * K1)>{};
-    static constexpr auto N0PerBlock = Number<NPerBlock / N1PerBlock>{};
+    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
+    static constexpr auto AK1 = Number<AK1Value>{};
+    static constexpr auto BK1 = Number<BK1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
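For orientation only, not part of the commit: with the example configuration from the first file (KPerBlock = 32, AK1 = 8, BK1 = 2), the compile-time constants the kernel now derives work out as in this minimal host-side sketch; the real code wraps them in ck::Number<> rather than plain ints.

    #include <numeric> // std::lcm

    constexpr int KPerBlock = 32; // assumed from the example config
    constexpr int AK1       = 8;
    constexpr int BK1       = 2;

    constexpr int AK0 = KPerBlock / AK1; // 4  -> A tile viewed as (AK0, M, AK1)
    constexpr int BK0 = KPerBlock / BK1; // 16 -> B tile viewed as (BK0, N, BK1)

    // LDS alignment becomes the lcm of the two K1 values instead of a single K1.
    constexpr int max_lds_align = std::lcm(AK1, BK1); // 8

    static_assert(AK0 * AK1 == BK0 * BK1, "both views cover the same KPerBlock");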
@@ -260,129 +138,48 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     using GridwiseGemmPipe = GridwiseGemmPipeline_v2<NumGemmKPrefetchStage>;
 #endif

-    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_n_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                if constexpr(BBlockLdsExtraN1)
-                {
-                    constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
-                        make_tuple(Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
-                        make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + K1),
-                                   Number<N1PerBlock>{} * K1 + K1,
-                                   K1,
-                                   I1));
-
-                    constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
-                        b_block_desc_k0_n0_n1_k1,
-                        make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
-                                   make_merge_transform_v3_division_mod(
-                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
-                                   make_pass_through_transform(K1)),
-                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
-                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-                    return b_block_desc_k0_n_k1_tmp;
-                }
-                else
-                {
-                    return make_naive_tensor_descriptor(
-                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-                }
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        return b_block_desc_k0_n_k1;
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                if constexpr(BBlockLdsExtraN1)
-                {
-                    constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
-                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
-                        make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + K1),
-                                   Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + K1),
-                                   Number<N1PerBlock>{} * K1 + K1,
-                                   K1,
-                                   I1));
-
-                    constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
-                        b_block_desc_b_k0_n0_n1_k1,
-                        make_tuple(make_pass_through_transform(Number<1>{}),
-                                   make_pass_through_transform(Number<K0PerBlock>{}),
-                                   make_merge_transform_v4_no_carry(
-                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
-                                   make_pass_through_transform(K1)),
-                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
-                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-                    return b_block_desc_b_k0_n_k1_tmp;
-                }
-                else
-                {
-                    return make_naive_tensor_descriptor(
-                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                        make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
-                                   Number<NPerBlock + 1>{} * K1,
-                                   K1,
-                                   I1));
-                }
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    max_lds_align);
-            }
-        }();
-
-        return b_block_desc_b_k0_n_k1;
-    }
+    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        return make_naive_tensor_descriptor(
+            make_tuple(AK0, Number<MPerBlock>{}, AK1),
+            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
+    }
+
+    __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<1>{}, AK0, Number<MPerBlock>{}, AK1),
+            make_tuple(AK0 * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
+                       Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
+                       AK1,
+                       I1));
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
+    {
+        // B matrix in LDS memory, dst of blockwise copy
+        return make_naive_tensor_descriptor(
+            make_tuple(BK0, Number<NPerBlock>{}, BK1),
+            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<1>{}, BK0, Number<NPerBlock>{}, BK1),
+            make_tuple(BK0 * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
+                       Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
+                       BK1,
+                       I1));
+    }

     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto max_lds_align = K1;
+        // lds max alignment
+        constexpr auto max_lds_align = math::lcm(AK1, BK1);

         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k0_m_k1_block_desc = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
+        constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size = math::integer_least_multiple(
             a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
@@ -405,34 +202,26 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const CMNGridDesc& c_m_n_grid_desc,
         const Block2CTileMap& block_2_ctile_map)
     {
-        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
-                      "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                           (NPerBlock % (NRepeat * NPerXDL)) == 0,
                       "Invalid tuning param!");

         const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
         const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
-        const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
-        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
+        const auto K =
+            a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3);

         // check gridwise gemm pipeline
-        const auto num_k_loop = K0 / K0PerBlock;
+        const auto num_k_loop = K / KPerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }

-        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
-             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
-             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
-             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
+        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1)))
             return false;

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

         if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
@@ -444,14 +233,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         return true;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K)
     {
-        // const bool has_main_k0_block_loop = K0 > K0PerBlock;
-        const index_t num_loop = K0 / K0PerBlock;
+        const index_t num_loop = K / KPerBlock;

         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
-        // return has_main_k0_block_loop;
     }

     __host__ __device__ static constexpr auto
@@ -517,8 +303,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
-
         // divide block work by [M, N]
         const auto block_work_idx =
             c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -541,51 +325,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = K1;
+        constexpr auto max_lds_align = math::lcm(AK1, BK1);

         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k0_m_k1_block_desc = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        constexpr auto a_b_k0_m_k1_block_desc = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
-                               Number<MPerBlock + 1>{} * K1,
-                               K1,
-                               I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    max_lds_align);
-            }
-        }();
+        constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
+        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();

         // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<1, K0PerBlock, MPerBlock, K1>,
+                                                Sequence<1, AK0, MPerBlock, AK1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ABlockTransferThreadClusterArrangeOrder,
                                                 FloatAB,
@@ -616,7 +373,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<1, K0PerBlock, NPerBlock, K1>,
+                                                Sequence<1, BK0, NPerBlock, BK1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 FloatAB,
@@ -649,6 +406,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         // register
         // sanity check
+        constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
+
         auto blockwise_gemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                 FloatAB,
@@ -659,7 +419,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                 NPerXDL,
                                                                 MRepeat,
                                                                 NRepeat,
-                                                                K1>{};
+                                                                KPack>{};

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -670,8 +430,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;

-        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
@@ -679,7 +439,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

         // gridwise GEMM pipeline
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(
+            (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / KPerBlock);

         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                           a_b_k0_m_k1_block_desc,