gaoqiong / composable_kernel · Commits

Commit a0a469e4, authored Feb 03, 2023 by aska-0096
Parent: 3ddd3578

    save progress

Showing 5 changed files with 501 additions and 632 deletions (+501 / -632).
Files changed:

  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                                           +7   -3
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +1   -1
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp               +487 -622
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                             +1   -1
  include/ck/utility/amd_wmma.hpp                                                                         +5   -5
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

```diff
@@ -129,7 +129,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
         return make_tuple(c_thread_m, c_thread_n);
     }

-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
+    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
+        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+        Tuple4 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -299,8 +303,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                                        B_K1,
                                                        B_K1>;

-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
```
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

```diff
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
```
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
```diff
@@ -18,72 +18,134 @@
 namespace ck {

+template <typename GridwiseGemm,
+          typename FloatA, typename FloatB, typename FloatC,
+          typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1,
+          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename AElementwiseOperation, typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename Block2CTileMap,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
+                         const FloatB* __restrict__ p_b0_grid,
+                         FloatC* __restrict__ p_c_grid,
+                         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+                         const BGridDesc_K0_N_K1 b0_grid_desc_k0_l_k1,
+                         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                             c_grid_desc_mblock_mperblock_nblock_nperblock,
+                         // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
+                         //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
+                         const AElementwiseOperation a_element_op,
+                         const BElementwiseOperation b_element_op,
+                         const CElementwiseOperation c_element_op,
+                         const Block2CTileMap block_2_ctile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        p_a_grid, p_b0_grid, p_c_grid, p_shared,
+        a_grid_desc_k0_m_k1, b0_grid_desc_k0_l_k1,
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        a_element_op, b_element_op, c_element_op, block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b0_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b0_grid_desc_k0_l_k1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx1100__))
+}
+
+// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
+// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
-template <typename FloatAB,
-          typename FloatGemmAcc,
+template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB0,
+          typename FloatB1,
+          typename FloatAcc,
           typename FloatCShuffle,
           typename FloatC,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename B1GridDesc_BK0_N_BK1,
+          typename AGridDesc_K0_M_K1,
+          typename B0GridDesc_K0_L_K1,
+          typename B1GridDesc_L0_N_L1,
           typename CGridDesc_M_N,
-          index_t NumGemmKPrefetchStage,
-          index_t BlockSize,
-          index_t MPerBlock,
-          index_t NPerBlock,
-          index_t KPerBlock,
+          index_t Gemm0MPerBlock,
+          index_t Gemm0LPerBlock,
+          index_t Gemm0K0PerBlock,
+          index_t Gemm0K1Value,
+          index_t Gemm0MPerWmma,
+          index_t Gemm0LPerWmma,
+          index_t Gemm0MRepeat,
+          index_t Gemm0LRepeat,
+          index_t Gemm1MPerBlock,
           index_t Gemm1NPerBlock,
           index_t Gemm1KPerBlock,
-          index_t AK1Value,
-          index_t BK1Value,
-          index_t B1K1Value,
-          index_t MPerXdl,
-          index_t NPerXdl,
-          index_t MXdlPerWave,
-          index_t NXdlPerWave,
-          index_t Gemm1NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          index_t Gemm1L0PerBlock,
+          index_t Gemm1L1Value,
+          index_t Gemm1MPerWmma,
+          index_t Gemm1NPerWmma,
+          index_t Gemm1MRepeat,
+          index_t Gemm1NRepeat,
+          typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_AK1,
-          bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
-          index_t ABlockLdsExtraM,
-          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
-          typename BBlockTransferThreadClusterArrangeOrder,
-          typename BBlockTransferSrcAccessOrder,
-          index_t BBlockTransferSrcVectorDim,
-          index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_BK1,
-          bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
-          index_t BBlockLdsExtraN,
-          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+          index_t ABlockTransferDstScalarPerVector_K1,
+          bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool ABlockLdsExtraM,
+          typename B0BlockTransferThreadClusterLengths_K0_L_K1,
+          typename B0BlockTransferThreadClusterArrangeOrder,
+          typename B0BlockTransferSrcAccessOrder,
+          index_t B0BlockTransferSrcVectorDim,
+          index_t B0BlockTransferSrcScalarPerVector,
+          index_t B0BlockTransferDstScalarPerVector_K1,
+          bool B0ThreadTransferSrcResetCoordinateAfterRun,
+          bool B0BlockLdsExtraN,
+          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
           index_t B1BlockTransferSrcVectorDim,
           index_t B1BlockTransferSrcScalarPerVector,
-          index_t B1BlockTransferDstScalarPerVector_BK1,
+          index_t B1BlockTransferDstScalarPerVector_L1,
           bool B1ThreadTransferSrcResetCoordinateAfterRun,
-          index_t B1BlockLdsExtraN,
-          index_t CShuffleMXdlPerWavePerShuffle,
-          index_t CShuffleNXdlPerWavePerShuffle,
+          bool B1BlockLdsExtraN,
+          index_t CShuffleMRepeatPerShuffle,
+          index_t CShuffleNRepeatPerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched,
-          bool PadN,
-          bool MaskOutUpperTriangle,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
+          index_t NumGemmKPrefetchStage = 1,
+          LoopScheduler LoopSched       = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
+struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
```
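For orientation, the reference semantics of the fused kernel can be written as three plain host-side loops: Gemm0 produces an [M x L] row block, a row-wise softmax is applied to it, and Gemm1 contracts the result with B1 over L. This is only a minimal sketch of the math the kernel tiles and pipelines on the GPU, not part of the diff; the function name and row-major layouts are assumptions for illustration.

```cpp
#include <cmath>
#include <vector>

// Naive reference for: C = softmax_rowwise(A * B0) * B1
void fused_gemm_softmax_gemm(const std::vector<float>& A,  // M x K, row-major
                             const std::vector<float>& B0, // K x L, row-major
                             const std::vector<float>& B1, // L x N, row-major
                             std::vector<float>& C,        // M x N, row-major
                             int M, int K, int L, int N)
{
    std::vector<float> acc(L);
    for(int m = 0; m < M; ++m)
    {
        // Gemm0: acc[l] = sum_k A[m][k] * B0[k][l]
        for(int l = 0; l < L; ++l)
        {
            acc[l] = 0.f;
            for(int k = 0; k < K; ++k)
                acc[l] += A[m * K + k] * B0[k * L + l];
        }
        // row-wise softmax over acc
        float mx = acc[0];
        for(int l = 1; l < L; ++l)
            mx = std::max(mx, acc[l]);
        float sum = 0.f;
        for(int l = 0; l < L; ++l)
        {
            acc[l] = std::exp(acc[l] - mx);
            sum += acc[l];
        }
        for(int l = 0; l < L; ++l)
            acc[l] /= sum;
        // Gemm1: C[m][n] = sum_l acc[l] * B1[l][n]
        for(int n = 0; n < N; ++n)
        {
            float c = 0.f;
            for(int l = 0; l < L; ++l)
                c += acc[l] * B1[l * N + n];
            C[m * N + n] = c;
        }
    }
}
```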
@@ -94,161 +156,127 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
// Gemm1
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
Gemm0K1Value
>
{};
static
constexpr
auto
N1
=
Number
<
Gemm1N1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
)
;
constexpr
auto
max_lds_align
=
K1
;
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
1
,
1
>
(
ABlockDesc_AK0_M_AK1
{});
}
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
Gemm1NWaves
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm1NXdlPerWave
,
Gemm1NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
return
a_block_desc_k0perblock_mperblock_k1
;
}
__host__
__device__
static
constexpr
auto
Get
A
BlockDescriptor_
A
K0PerBlock_
M
PerBlock_
A
K1
()
__host__
__device__
static
constexpr
auto
Get
B
BlockDescriptor_K0PerBlock_
N
PerBlock_K1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
constexpr
auto
max_lds_align
=
K1
;
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B1K0
,
Number
<
Gemm1NPerBlock
>
{},
B1K1
),
make_tuple
(
Number
<
Gemm1NPerBlock
+
B1BlockLdsExtraN
>
{}
*
B1K1
,
B1K1
,
I1
));
return
b_block_desc_k0perblock_nperblock_k1
;
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
M
XdlPerWave
*
MPer
Xdl
);
constexpr
index_t
NWave
=
Gemm1
NPerBlock
/
(
Gemm1NXdlPerWave
*
NPer
Xdl
);
constexpr
index_t
MWave
=
MPerBlock
/
(
M
Repeat
*
MPer
Wmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPer
Wmma
);
constexpr
auto
c_shuffle_block_desc_m
block_mperblock_nblock_nperblock
=
constexpr
auto
c_shuffle_block_desc_m
shrepeat_mpershrepeat_nshrepeat_npershrepeat
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleM
XdlPerWave
PerShuffle
*
MWave
*
MPer
Xdl
>
{},
Number
<
CShuffleM
Repeat
PerShuffle
*
MWave
*
MPer
Wmma
>
{},
I1
,
Number
<
CShuffleN
XdlPerWave
PerShuffle
*
NWave
*
NPer
Xdl
>
{}));
Number
<
CShuffleN
Repeat
PerShuffle
*
NWave
*
NPer
Wmma
>
{}));
return
c_shuffle_block_desc_m
block_mperblock_nblock_nperblock
;
return
c_shuffle_block_desc_m
shrepeat_mpershrepeat_nshrepeat_npershrepeat
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a_block_space_size_aligned
+
SharedMemTrait
::
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0perblock_mperblock_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0perblock_nperblock_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size_aligned
*
sizeof
(
FloatA
)
+
b_block_space_size_aligned
*
sizeof
(
FloatB
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b0_grid_desc_k0_l_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerWmma
*
MRepeat
)
==
0
)
&&
(
NPerBlock
%
(
NRepeat
*
NPerWmma
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b0_grid_desc_k0_l_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b0_grid_desc_k0_l_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b0_grid_desc_k0_l_k1
.
GetLength
(
I2
)))
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
}
// check gemm0 gridwise gemm pipeline
const
auto
num_gemm0_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm0_k_loop
))
{
return
false
;
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
// check gemm1 gridwise gemm pipeline
if
(
!
(
NPerBlock
%
Gemm1KPerBlock
==
0
))
{
return
false
;
}
const
auto
num_gemm1_k_inner_loop
=
NPerBlock
/
Gemm1KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm1_k_inner_loop
))
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
...
...
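For a sense of scale, here is a minimal sketch of the LDS budget implied by the new GetSharedMemoryNumberOfByte formula. The tile sizes and fp16 element type are illustrative assumptions, not values taken from the diff, and the alignment padding applied by integer_least_multiple is ignored.

```cpp
#include <cstddef>

// Assumed tile shape: K0PerBlock = 4, K1 = 8, MPerBlock = NPerBlock = 128, fp16 A/B.
constexpr std::size_t k0 = 4, k1 = 8, m = 128, n = 128;
constexpr std::size_t a_elems   = k0 * m * k1; // 4096 elements for the A tile
constexpr std::size_t b_elems   = k0 * n * k1; // 4096 elements for the B0 tile
constexpr std::size_t lds_bytes = (a_elems + b_elems) * sizeof(unsigned short); // fp16-sized

static_assert(lds_bytes == 16384,
              "A + B0 tiles take ~16 KiB, well under the 64 KiB of LDS a workgroup can use");
```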
```diff
@@ -264,7 +292,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const index_t num_loop = K / KPerBlock;
+        const index_t num_loop = K / (K0PerBlock * K1);

         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
```
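A quick worked example of the new loop count, using hypothetical numbers chosen only for illustration:

```cpp
// With K = 4096, K0PerBlock = 4 and K1 = 8, the K dimension is consumed in
// K / (K0PerBlock * K1) = 128 block-sized steps, so the main-loop path is taken.
constexpr int K = 4096, K0PerBlock = 4, K1 = 8;
constexpr int num_loop = K / (K0PerBlock * K1);
static_assert(num_loop == 128, "example only");
```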
```diff
@@ -276,12 +304,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto N = c_grid_desc_m_n.GetLength(I1);

         const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / Gemm1NPerBlock;
+        const auto NBlock = N / NPerBlock;

         const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
```
```diff
@@ -289,284 +317,225 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
-            c_grid_desc_m_n);
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
     using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    struct SharedMemTrait
-    {
-        // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b_block_desc_bk0_n_bk1  = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
-        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
-
-        static constexpr auto a_block_space_size_aligned =
-            math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b_block_space_size_aligned =
-            math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned =
-            math::integer_least_multiple(b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-
-        static constexpr auto a_block_space_offset  = 0;
-        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
-        static constexpr auto b1_block_space_offset = 0;
-
-        // LDS allocation for reduction
-        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
-        static constexpr auto reduction_space_offset = 0;
-
-        // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-    };
-
-    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
-                               const FloatAB* __restrict__ p_b1_grid,
+    template <bool HasMainKBlockLoop,
+              typename C0MatrixMask,
+              typename Block2CTileMap = DefaultBlock2CTileMap>
+    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
+                               const FloatB0* __restrict__ p_b0_grid,
+                               const FloatB1* __restrict__ p_b1_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
+                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                               const B0GridDesc_K0_L_K1& b0_grid_desc_k0_l_k1,
+                               const B1GridDesc_L0_N_L1& b1_grid_desc_l0_n_l1,
+                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const AElementwiseOperation& a_element_op,
-                               const BElementwiseOperation& b_element_op,
+                               const B0ElementwiseOperation& b0_element_op,
                                const AccElementwiseOperation& acc_element_op,
                                const B1ElementwiseOperation& b1_element_op,
                                const CElementwiseOperation& c_element_op,
-                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
-                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map,
-                               const C0MatrixMask& c0_matrix_mask)
+                               const C0MatrixMask& c0_matrix_mask,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         // clang-format off
+/*******************************************************************************/
+// Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
+/*******************************************************************************/
+// BlockIdx.x -> [BlockId.m, BlockId.n]
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
         if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

-        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
+        // Store BlockId into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
-        const index_t gemm1_n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        //
-        // set up Gemm0
-        //
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+/*******************************************************************************/
+// set up Gemm0
+/*******************************************************************************/
+// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        constexpr auto max_lds_align = K1;
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

         // A matrix blockwise copy
-        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-            ThisThreadBlock, AElementwiseOperation, tensor_operation::element_wise::PassThrough,
-            InMemoryDataOperationEnum::Set, Sequence<AK0, MPerBlock, AK1>,
-            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
-            FloatAB, FloatAB, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
-            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
-            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, 1, 1,
-            true /* SrcResetCoord */, true /* DstResetCoord */, NumGemmKPrefetchStage>(
-            a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op,
-            a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0),
-            tensor_operation::element_wise::PassThrough{});
+        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            /* SrcElementwiseOperation */ AElementwiseOperation,
+            /* DstElementwiseOperation */ ck::tensor_operation::element_wise::PassThrough,
+            /* DstInMemOp */ InMemoryDataOperationEnum::Set,
+            /* BlockSliceLengths */ Sequence<K0PerBlock, MPerBlock, K1>,
+            /* ThreadClusterLengths */ ABlockTransferThreadClusterLengths_K0_M_K1,
+            /* ThreadClusterArrangeOrder */ ABlockTransferThreadClusterArrangeOrder,
+            /* SrcData */ FloatA, /* DstData */ FloatA,
+            /* SrcDesc */ decltype(a_grid_desc_k0_m_k1),
+            /* DstDesc */ decltype(a_block_desc_k0perblock_mperblock_k1),
+            /* SrcDimAccessOrder */ ABlockTransferSrcAccessOrder,
+            /* DstDimAccessOrder */ Sequence<0, 1, 2>,
+            /* SrcVectorDim */ ABlockTransferSrcVectorDim, /* DstVectorDim */ 2,
+            /* SrcScalarPerVector */ ABlockTransferSrcScalarPerVector,
+            /* DstScalarPerVector */ ABlockTransferDstScalarPerVector_K1,
+            /* SrcScalarStrideInVector */ 1, /* DstScalarStrideInVector */ 1,
+            /* ThreadTransferSrcResetCoordinateAfterRun */ AThreadTransferSrcResetCoordinateAfterRun,
+            /* ThreadTransferDstResetCoordinateAfterRun */ true>(
+            a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op,
+            a_block_desc_k0perblock_mperblock_k1, make_multi_index(0, 0, 0),
+            ck::tensor_operation::element_wise::PassThrough{});

         // B matrix blockwise copy
-        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
-            InMemoryDataOperationEnum::Set, Sequence<BK0, NPerBlock, BK1>,
-            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
-            FloatAB, FloatAB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
-            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
-            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, 1, 1,
-            true /* SrcResetCoord */, true /* DstResetCoord */, NumGemmKPrefetchStage>(
-            b_grid_desc_bk0_n_bk1,
-            make_multi_index(0, 0, 0), // will loop over GemmN dimension
-            b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0),
-            tensor_operation::element_wise::PassThrough{});
+        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock, BElementwiseOperation,
+            ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set,
+            Sequence<K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1,
+            BBlockTransferThreadClusterArrangeOrder, FloatB, FloatB,
+            decltype(b0_grid_desc_k0_l_k1), decltype(b_block_desc_k0perblock_nperblock_k1),
+            BBlockTransferSrcAccessOrder, Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2,
+            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, 1, 1,
+            BThreadTransferSrcResetCoordinateAfterRun, true>(
+            b0_grid_desc_k0_l_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op,
+            b_block_desc_k0perblock_nperblock_k1, make_multi_index(0, 0, 0),
+            tensor_operation::element_wise::PassThrough{});

-        // Fused Gemm+Gemm pipeline
-        // for n in N0:
-        //   for k in K0:
-        //     acc[m][n] += A[m][k] * B0[k][n]
-        //   acc1[m][o] += acc[m][n] * B1[n][o]
-
-        // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-
-        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
-            BlockSize, FloatAB, FloatGemmAcc,
-            decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1),
-            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
-            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
-            MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack,
-            true>{}; // TransposeC
+/*******************************************************************************/
+// Gemm0
+        constexpr auto WmmaK = 16;
+        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
+
+        auto blockwise_gemm0 = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
+            BlockSize, FloatA, FloatB, FloatAcc,
+            decltype(a_block_desc_k0perblock_mperblock_k1),
+            decltype(b_block_desc_k0perblock_nperblock_k1),
+            MPerWmma, NPerWmma, MRepeat, NRepeat, KPack>{};

+        // Prepare Register for A*B0 matrix
         auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();

+/*******************************************************************************/
         // LDS allocation for A and B: be careful of alignment
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
-        const auto a_block_reset_copy_step =
-            make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
-        const auto b_block_reset_copy_step =
-            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
-
-        // gridwise GEMM pipeline
-        // Only supports LoopScheduler::Default
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopScheduler::Default>();
-
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock);
-
-        //
-        // set up Gemm1
-        //
+        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
+
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatA*>(p_shared),
+            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
+            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+
+        // Shift Per SUB_K
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        const auto a_block_reset_copy_step =
+            make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+        const auto b_block_reset_copy_step =
+            make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
+
         // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
         constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
             blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

         constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
         constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
         constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
         constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
         constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
         constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
         constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
         constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

         constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

         // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
         // n0_n1_n2_n3 -> k0
         // m0_m1_m2 -> m
         // n4 -> k1
         // NOTE: had to use merge_v3 or will spit out compilation errors
         constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
             acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
             make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                        make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                        make_pass_through_transform(n4)),
             make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

         // A1 matrix in AccVGPR
         // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
         constexpr auto AccN3 =
             blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
         constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
             Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});

         constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
         constexpr auto A1ThreadSliceM  = A1ThreadSlice_K0_M_K1[I1];
         constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
         constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
             A1ThreadSlice_K0_M_K1,
             make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));

+/*******************************************************************************/
+// softmax
+/*******************************************************************************/
+        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatAcc*>(p_shared),
+            math::integer_least_multiple(BlockSize, max_lds_align));
+
+        // get acc0 8D thread cluster
+        constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
+            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
+            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+        constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
+        constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
+        constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
+        constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
+        constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
+        constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
+        constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
+        constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
+
+        // get acc0 thread map
+        constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
+                       make_pass_through_transform(I1)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(
+                make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+        const auto threadid_to_m_n_thread_cluster_adaptor =
+            chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
+
+        // get acc0 2D thread cluster & 2D thread slice
+        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
+        constexpr auto thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
+
+        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
+                                                  FloatGemmAcc,
+                                                  decltype(threadid_to_m_n_thread_cluster_adaptor),
+                                                  decltype(thread_cluster_desc_m_n),
+                                                  decltype(thread_slice_desc_m_n)>{};
+
+        // Initialize running sum and max of exponentiating row vectors
+        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
+        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
+        running_sum     = 0;
+        running_sum_new = 0;
+        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
+        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
+
+/*******************************************************************************/
+// set up Gemm1
+/*******************************************************************************/
         // B1 matrix in LDS memory, dst of blockwise copy
-        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        constexpr auto b1_block_desc_l0perblock_nperblock_l1 =
+            GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1();

         // A1 matrix blockwise copy
         auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-            FloatGemmAcc, FloatAB,
+            FloatAcc, FloatA,
             decltype(acc_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
```
```diff
@@ -574,7 +543,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             Sequence<1, 0, 2>,
             2,
             n4>{tensor_operation::element_wise::PassThrough{}};

         // B1 matrix blockwise copy
         auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
             ThisThreadBlock,
```
```diff
@@ -605,142 +574,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             b1_block_desc_bk0_n_bk1,
             make_multi_index(0, 0, 0),
             tensor_operation::element_wise::PassThrough{});

-        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            a1_thread_desc_k0_m_k1.GetElementSpaceSize());
-
-        // reuse LDS space for gemm0's b_block_buf
-        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
-        // selected_mfma.k_per_blk <= Gemm1KPack
-        //
-        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
-        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
-        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
-        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
-        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
-        // therefore we may just as well assign Gemm1KPack = group_size
-        constexpr index_t Gemm1KPack =
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
-
-        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
-            BlockSize, FloatAB, FloatGemmAcc,
-            decltype(a1_thread_desc_k0_m_k1), decltype(b1_block_desc_bk0_n_bk1),
-            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
-            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
-            MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, MPerXdl, NPerXdl, MXdlPerWave,
-            Gemm1NXdlPerWave, Gemm1KPack,
-            true,       // TransposeC
-            Gemm1KPack, // AMmaKStride
-            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
-            // BMmaKStride
-            make_tuple(0, 0, 0, 0)}; // A_origin
-
-        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
-
-        //
-        // Blockwise softmax
-        //
-        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
-            SharedMemTrait::reduction_space_size_aligned);
-
-        // get acc0 8D thread cluster
-        constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
-        constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
-        constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
-        constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
-        constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
-        constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
-        constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
-        constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
-        constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
-
-        // get acc0 thread map
-        constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
-                       make_pass_through_transform(I1)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
-        const auto threadid_to_m_n_thread_cluster_adaptor =
-            chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
-
-        // get acc0 2D thread cluster & 2D thread slice
-        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
-            make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
-        constexpr auto thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-            make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
-
-        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
-                                                  FloatGemmAcc,
-                                                  decltype(threadid_to_m_n_thread_cluster_adaptor),
-                                                  decltype(thread_cluster_desc_m_n),
-                                                  decltype(thread_slice_desc_m_n)>{};
-
-        const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
+        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+            a1_thread_desc_k0_m_k1.GetElementSpaceSize());
+
+        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatB*>(p_shared),
+            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+
+        auto blockwise_gemm1 = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+            BlockSize, FloatA, FloatB, FloatAcc,
+            decltype(a1_thread_desc_k0perblock_mperblock_k1),
+            decltype(b1_block_desc_k0perblock_nperblock_k1),
+            MPerWmma, NPerWmma, MRepeat, NRepeat, KPack>{make_tuple(0, 0, 0, 0, 0)};
+
+        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
+
+        const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

         // Initialize C
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true> c_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc1_thread_buf.Size(), true> c_thread_buf;
         c_thread_buf.Clear();

-        // Initialize running sum and max of exponentiating row vectors
-        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
-        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
-        running_sum     = 0;
-        running_sum_new = 0;
-        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
-        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
-
-        // gemm1 K loop
+/*******************************************************************************/
+// Flash Attention
+// Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention
+// with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
         index_t gemm1_k_block_outer_index = 0;
+        // Outer loop, along GEMM_L
+        // Inner loop, along GEMM_K
         do
         {
             auto n_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
             if(c0_matrix_mask.IsTileSkippable(
                    m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
             {
                 continue;
             }
-            // gemm0
-            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
-                a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf,
-                a_block_buf, a_block_slice_copy_step,
-                b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf,
-                b_block_buf, b_block_slice_copy_step,
-                blockwise_gemm, acc_thread_buf, num_k_block_main_loop);
+            // gemm0 start, A-B swaped
+            const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+            GridwiseGemmPipe::template Run<HasMainKBlockLoop>(
+                a_grid_desc_k0_m_k1, a_block_desc_k0perblock_mperblock_k1, a_blockwise_copy,
+                a_grid_buf, a_block_buf, a_block_slice_copy_step,
+                b0_grid_desc_k0_l_k1, b_block_desc_k0perblock_nperblock_k1, b_blockwise_copy,
+                b0_grid_buf, b_block_buf, b_block_slice_copy_step,
+                blockwise_gemm, acc_thread_buf, K0BlockMainLoop);

             // do MNK padding or upper triangular masking
             if constexpr(MaskOutUpperTriangle || PadN)
             {
```
```diff
@@ -797,13 +680,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 });
             }
             else
             {
                 static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
                     acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
                 });
             }

-            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
+            block_sync_lds();
+            // gemm0 end

-            // softmax
+            // Tiled softmax start
             SoftmaxBuf& max = blockwise_softmax.max_value_buf;
             SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
```
```diff
@@ -814,7 +699,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_max_new = mathext::max(max, running_max);
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                               mathext::exp(max - running_max_new) * sum;

+            // Intra-Row data permutation, make swizzled A input for WMMA
+            __builtin_amdgcn_permlane16(0xeca86420, 0xfdb97531);
+            // Low/high row move data to low/high half of thread buffer
+            /* thread copy*/
+            // Inter-Row data permutation, fullfill data duplication requirement
+            __builtin_amdgcn_permlanex16(0x76543210, 0xfedcba98);
+
             // gemm1
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
```
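The running-max/running-sum update above is the standard online (streaming) softmax correction used by Flash Attention. A minimal scalar sketch of the same recurrence for a single row is shown below; the struct and method names are invented for illustration and are not part of the kernel.

```cpp
#include <cmath>

// Streaming softmax statistics for one row:
//   m_new = max(m_old, m_tile)
//   l_new = exp(m_old - m_new) * l_old + exp(m_tile - m_new) * l_tile
struct RunningSoftmaxState
{
    float running_max = -INFINITY;
    float running_sum = 0.f;

    // Fold in one tile's row maximum and its row sum of exp(x - tile_max).
    void update(float tile_max, float tile_sum)
    {
        const float new_max = std::fmax(running_max, tile_max);
        running_sum = std::exp(running_max - new_max) * running_sum +
                      std::exp(tile_max - new_max) * tile_sum;
        running_max = new_max;
    }
};
```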
```diff
@@ -841,6 +732,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 if constexpr(num_gemm1_k_block_inner_loop > 1)
                 {
                     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
+                        // Data cast from FloatAcc to FloatA happen here
                         a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                               make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
                                               acc_thread_buf,
```
```diff
@@ -879,14 +771,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             }
             // end gemm1

-            // workaround compiler issue; see ck/ck.hpp
-            if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-                         is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
-                         Gemm1NPerBlock == 128)
-            {
-                __builtin_amdgcn_sched_barrier(0);
-            }
-
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
             constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
```
```diff
@@ -910,8 +794,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 FloatGemmAcc c_new =
                     (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                      math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                    running_sum_new[iM]; // Formula by Dao et al.,
-                                         // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
+                    running_sum_new[iM];

                 c_thread_buf(I) = c_new; // O_new
             });
```
@@ -927,120 +810,102 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_sum
=
running_sum_new
;
block_sync_lds
();
// wait for gemm1 LDS read
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
//
shuffle C and write out
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
/*******************************************************************************/
//
write out to C, implement shuffle
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
Gemm1NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
blockwise_gemm
.
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
// T
ODO: hacky, fix it!
constexpr
auto
c_
thread
_desc_m
0_n0_m1_n1_m2_n2_n3_n4
=
gemm1_
blockwise_gemm
.
GetC
Thread
Descriptor_M
0_N0_M1_N1_M2_N2_N3_N4
();
// T
his API Provide All dimension (size) you need
constexpr
auto
c_
block
_desc_m
repeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
=
blockwise_gemm
.
GetC
Block
Descriptor_M
Repeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
=
gemm1_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
MWave
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I1
);
constexpr
auto
MSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I2
);
constexpr
auto
NWave
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I4
);
constexpr
auto
NThreadPerSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I5
);
constexpr
auto
MAccVgprs
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I6
);
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I4
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I5
);
constexpr
auto
N3
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I6
);
constexpr
auto
N4
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatCShuffle
*>
(
p_shared
),
c_shuffle_block_desc_m
block_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
c_shuffle_block_desc_m
shrepeat_mpershrepeat_nshrepeat_npershrepeat
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m
0_n0_m1_n1_m2_n2_n3_n4
=
transform_tensor_descriptor
(
c_shuffle_block_desc_m
block_mperblock_nblock_nperblock
,
constexpr
auto
c_block_desc_m
repeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
transform_tensor_descriptor
(
c_shuffle_block_desc_m
shrepeat_mpershrepeat_nshrepeat_npershrepeat
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
)),
// M2 = MPerXdl
Number
<
CShuffleMRepeatPerShuffle
>
{},
// MRepeat per shuffle repeat
MWave
,
// MWave
MSubGroup
,
// MSubGroup * MAccVgprs = MPerWmma
MAccVgprs
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
,
// N2 * N3 * N4 = NPerXdl
N3
,
N4
))),
Number
<
CShuffleNRepeatPerShuffle
>
{},
// NRepeat per shuffle repeat
NWave
,
// NWave
NThreadPerSubGroup
))),
// NThreadPerSubGroup = NPerWmma
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
1
,
2
,
6
>
{},
Sequence
<>
{},
Sequence
<
3
,
4
,
5
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
gemm1_blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
        const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));
        // shuffle: threadwise copy C from VGPR to LDS
        auto c_thread_copy_vgpr_to_lds =
            ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                               FloatCShuffle,
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                               decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                               tensor_operation::element_wise::PassThrough,
                                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                                        CShuffleNXdlPerWavePerShuffle,
                                               decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                               decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                               ck::tensor_operation::element_wise::PassThrough,
                                               Sequence<CShuffleMRepeatPerShuffle,
                                                        I1,
                                                        I1,
                                                        CShuffleNRepeatPerShuffle,
                                                        I1,
                                                        N2,
                                                        I1,
                                                        N4>,
                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                               7,
                                               1,
                                                        MAccVgprs>,
                                               Sequence<0, 1, 2, 3, 4, 5, 6>,
                                               6,
                                               1, // vector write pixel
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>{
                c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 0,
                                 n_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I2],
                                 n_thread_data_on_block_idx[I3],
                                 n_thread_data_on_block_idx[I4]),
                tensor_operation::element_wise::PassThrough{}};
                                 m_thread_data_on_block_idx[I3]),
                ck::tensor_operation::element_wise::PassThrough{}};
        // shuffle: blockwise copy C from LDS to global
        auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
...
...
@@ -1048,47 +913,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            CElementwiseOperation,      // ElementwiseOperation,
            CGlobalMemoryDataOperation, // DstInMemOp,
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                     CShuffleNRepeatPerShuffle * NWave * NPerWmma>,    // BlockSliceLengths,
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatCShuffle,        // typename SrcData,
            FloatC,               // typename DstData,
            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
            3,                                              // index_t VectorDim,
            CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
            true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
            false> // bool ThreadTransferDstResetCoordinateAfterRun>
            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
             make_multi_index(0, 0, 0, 0),
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
             c_element_op};
        // space filling curve for local reg & global memory
        // space filling curve for threadwise C in VGPR
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                              Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
            SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
                              Sequence<0, 1, 2, 3, 4, 5, 6>,
                              Sequence<CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, N2, 1, N4>>{};
                                       MAccVgprs>>{};
        // space filling curve for shuffled blockwise C in global mem
        constexpr auto sfc_c_global =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, Gemm1NPerBlock>,
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
                                       CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
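        // Minimal model of the space-filling curves above, under assumed illustrative values
        // (MRepeat = NRepeat = 4, CShuffleMRepeatPerShuffle = CShuffleNRepeatPerShuffle = 1);
        // num_access is the product of (tensor length / per-access slice length) over the
        // dimensions that are actually stepped.
        //
        // #include <cstdio>
        //
        // int main()
        // {
        //     constexpr int MRepeat = 4, NRepeat = 4;      // assumed
        //     constexpr int CShuffleMRepeatPerShuffle = 1; // assumed
        //     constexpr int CShuffleNRepeatPerShuffle = 1; // assumed
        //
        //     constexpr int num_access = (MRepeat / CShuffleMRepeatPerShuffle) *
        //                                (NRepeat / CShuffleNRepeatPerShuffle); // 16 rounds
        //
        //     // Each access advances the (m_repeat, n_repeat) window; a plain row-major walk
        //     // is one valid ordering (the real curve may snake to keep forward steps small).
        //     for(int a = 0; a < num_access; ++a)
        //     {
        //         const int m_step = a / (NRepeat / CShuffleNRepeatPerShuffle);
        //         const int n_step = a % (NRepeat / CShuffleNRepeatPerShuffle);
        //         std::printf("access %2d -> repeat window (%d, %d)\n", a, m_step, n_step);
        //     }
        //     return 0;
        // }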
...
...
@@ -1099,10 +964,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            block_sync_lds();

            // each thread write its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_thread_buf,
                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                          c_shuffle_block_buf);

            // make sure it's safe to read from LDS
...
...
@@ -1110,7 +975,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            // each block copy its data from LDS to global
            c_shuffle_block_copy_lds_to_global.Run(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            c_shuffle_block_copy_lds_to_global.Run(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                                                   c_shuffle_block_buf,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c_grid_buf);
...
...
@@ -1118,13 +983,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            if constexpr(access_id < num_access - 1)
            {
                constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                // move on C
                c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
            }
        });
    }
    // clang-format on
}
};
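// Host-side sketch of the C-shuffle epilogue loop above: each shuffle round stages a slice
// of the per-thread accumulators through a shared tile and then writes it to the output,
// advancing the destination window between rounds. Sizes and the single-thread
// simplification are assumptions for illustration only.
//
// #include <cstdio>
// #include <vector>
//
// int main()
// {
//     constexpr int num_access = 4; // e.g. shuffle rounds per block tile (assumed)
//     constexpr int tile_elems = 8; // elements staged per round (assumed)
//
//     std::vector<float> vgpr(num_access * tile_elems, 1.0f); // stand-in for c_thread_buf
//     std::vector<float> lds(tile_elems);                     // stand-in for the LDS tile
//     std::vector<float> global(num_access * tile_elems);     // stand-in for c_grid_buf
//     int dst_window = 0;
//
//     for(int access_id = 0; access_id < num_access; ++access_id)
//     {
//         // block_sync_lds(): everyone must be done reading the previous tile
//         // "VGPR -> LDS": copy this round's slice of the accumulators into the tile
//         for(int i = 0; i < tile_elems; ++i)
//             lds[i] = vgpr[access_id * tile_elems + i];
//         // block_sync_lds(): tile is fully written before anyone reads it
//         // "LDS -> global": write the staged tile at the current destination window
//         for(int i = 0; i < tile_elems; ++i)
//             global[dst_window + i] = lds[i];
//         if(access_id < num_access - 1)
//             dst_window += tile_elems; // MoveDstSliceWindow(...)
//     }
//     std::printf("wrote %zu elements\n", global.size());
//     return 0;
// }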
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
a0a469e4
...
...
@@ -414,7 +414,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
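        // A minimal sketch, assuming integer_least_multiple(x, y) rounds x up to the nearest
        // multiple of y; e.g. K1 = 8 with WmmaK = 16 would give KPack = 16, while K1 = 32
        // would give KPack = 32.
        //
        // constexpr int integer_least_multiple(int x, int y) { return ((x + y - 1) / y) * y; }
        // static_assert(integer_least_multiple(8, 16) == 16, "round up to one WmmaK");
        // static_assert(integer_least_multiple(32, 16) == 32, "already a multiple of WmmaK");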
        auto blockwise_gemm =
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                  FloatA,
                                                                  FloatB,
                                                                  FloatAcc,
...
...
include/ck/utility/amd_wmma.hpp
View file @
a0a469e4
...
...
@@ -23,11 +23,11 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
        // * With inline assembly, the duplicated data loads must be eliminated by hand; the
        // compiler won't delete them for you.
        amd_assembly_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        // reg_c.template AsType<float8_t>()(Number<0>{}) =
        //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
        //     AsType<float8_t>()[Number<0>{}]);
        // amd_assembly_wmma_f32_16x16x16_f16_w32(
        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
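// Scalar reference for what a 16x16x16 f16 x f16 -> f32 WMMA accumulate computes
// (D = A * B + C); the hardware instruction distributes the 16x16 tiles across a wavefront,
// but the per-element arithmetic is just this. _Float16 support from the host compiler is
// assumed; a GPU toolchain would use its own half type instead.
//
// #include <cstdio>
//
// constexpr int M = 16, N = 16, K = 16;
//
// void wmma_ref(const _Float16 a[M][K], const _Float16 b[K][N], float c[M][N])
// {
//     for(int m = 0; m < M; ++m)
//         for(int n = 0; n < N; ++n)
//         {
//             float acc = c[m][n]; // accumulate into the f32 C tile
//             for(int k = 0; k < K; ++k)
//                 acc += static_cast<float>(a[m][k]) * static_cast<float>(b[k][n]);
//             c[m][n] = acc;
//         }
// }
//
// int main()
// {
//     static _Float16 a[M][K], b[K][N];
//     static float c[M][N] = {};
//     for(int i = 0; i < M; ++i)
//         for(int k = 0; k < K; ++k)
//             a[i][k] = static_cast<_Float16>(1), b[k][i] = static_cast<_Float16>(1);
//     wmma_ref(a, b, c);
//     std::printf("c[0][0] = %f (expected 16)\n", c[0][0]); // each dot product sums 16 ones
//     return 0;
// }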
...
...