Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5df713ef
"example/vscode:/vscode.git/clone" did not exist on "fe274f57efd7324d8cb8b9bd6be4832a393cb9c4"
Commit
5df713ef
authored
Feb 11, 2023
by
aska-0096
Browse files
save progress
parent
a6b2f1c1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
923 additions
and
767 deletions
+923
-767
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+31
-29
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
...tmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+70
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+365
-443
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+405
-284
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+2
-2
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+49
-1
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
5df713ef
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_
o
= Softmax(A_g_m_k * B0_g_k_
n
) * B1_g_
n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_
n
= Softmax(A_g_m_k * B0_g_k_
l
) * B1_g_
l_n
|-----------------|
|-----------------|
Gemm0
Gemm0
|-------------------------------------|
|-------------------------------------|
...
@@ -39,7 +39,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -39,7 +39,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
F32
;
using
Acc0DataType
=
F32
;
using
Acc1DataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
...
@@ -67,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
...
@@ -67,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_
Xdl
_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_
Wmma
_CShuffle
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -76,11 +77,12 @@ using DeviceGemmInstance =
...
@@ -76,11 +77,12 @@ using DeviceGemmInstance =
ADataType
,
ADataType
,
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
Acc
1
DataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
AElementOp
,
AElementOp
,
B0ElementOp
,
B0ElementOp
,
Acc0ElementOp
,
Acc0ElementOp
,
...
@@ -91,21 +93,21 @@ using DeviceGemmInstance =
...
@@ -91,21 +93,21 @@ using DeviceGemmInstance =
TensorSpecB0
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecB1
,
TensorSpecC
,
TensorSpecC
,
1
,
256
,
256
,
128
,
// MPerBlock
128
,
// MPerBlock
128
,
// NPerBlock
128
,
// LPerBlock
32
,
// KPerBlock
4
,
// K0PerBlock
64
,
// Gemm1NPerBlock
8
,
// K1
32
,
// Gemm1KPerBlock
64
,
// NPerBlock
8
,
// AK1
4
,
// L0PerBlock
8
,
// BK1
8
,
// L1
2
,
// B1K1
16
,
// MPerWMMA
32
,
// MPerXDL
16
,
// LPerWMMA
32
,
// NPerXDL
16
,
// NPerWMMA
1
,
// MXdlPerWave
//Per repeat = wave_m = wave_num, wave_n = 1
4
,
// NXdlPerWave
1
,
// MRepeat
2
,
// Gemm1NXdlPerWave
8
,
// LRepeat
4
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
@@ -113,44 +115,44 @@ using DeviceGemmInstance =
...
@@ -113,44 +115,44 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
4
,
64
,
1
>
,
// B
0
BlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
4
,
64
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
0
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
8
,
2
,
8
,
false
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
B0DataType
,
AccDataType
,
Acc
0
DataType
,
AccDataType
,
Acc
1
DataType
,
AElementOp
,
AElementOp
,
B0ElementOp
,
B0ElementOp
,
Acc0ElementOp
>
;
Acc0ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
ADataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
Acc
0
DataType
,
ADataType
,
Acc
0
DataType
>
;
// Ref Gemm1: fp16 in, fp16 out
// Ref Gemm1: fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
AccDataType
,
Acc
1
DataType
,
AElementOp
,
AElementOp
,
B1ElementOp
,
B1ElementOp
,
CElementOp
>
;
CElementOp
>
;
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
View file @
5df713ef
...
@@ -198,7 +198,7 @@ int run(int argc, char* argv[])
...
@@ -198,7 +198,7 @@ int run(int argc, char* argv[])
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
Acc
0
DataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
5df713ef
...
@@ -129,11 +129,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
...
@@ -129,11 +129,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return
make_tuple
(
c_thread_m
,
c_thread_n
);
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
}
using
Tuple5
=
decltype
(
CalculateAThreadOriginDataIndex
());
// using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
(
// __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
// Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
// Tuple4 b_origin = CalculateBThreadOriginDataIndex())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
// : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
()
{
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
...
@@ -303,8 +304,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
...
@@ -303,8 +304,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
B_K1
,
B_K1
,
B_K1
>
;
B_K1
>
;
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
;
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
// AThreadCopy a_thread_copy_;
// BThreadCopy b_thread_copy_;
};
};
// block wise level pipe designed for inline asm
// block wise level pipe designed for inline asm
...
@@ -425,6 +428,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
...
@@ -425,6 +428,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
return
make_tuple
(
c_thread_m
,
c_thread_n
);
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
blk_idx
[
I0
],
waveId_m
,
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
()
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
()
{
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
...
@@ -438,6 +460,30 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
...
@@ -438,6 +460,30 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
"wrong!"
);
}
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
// constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
// constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
...
@@ -483,6 +529,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
...
@@ -483,6 +529,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Provide dimension size
// Provide dimension size
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
5df713ef
...
@@ -22,186 +22,97 @@ namespace ck {
...
@@ -22,186 +22,97 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
// Computes C = A * B0 * B1
typename
FloatAB
,
// MN = MK * KL * LN
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
,
c0_matrix_mask
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
// ^^^^^^^^^^^ (Acc1)
template
<
index_t
NumDimG
,
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimM
,
index_t
NumDim
N
,
index_t
NumDim
L
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDim
O
,
// NumDimGemm1
N
index_t
NumDimN
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
B
0
DataType
,
typename
B1DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc0BiasDataType
,
typename
Acc0DataType
,
typename
Acc1BiasDataType
,
typename
Acc1BiasDataType
,
typename
Gemm
AccDataType
,
typename
Acc
1
DataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
B
0
ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
B
0
Spec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
CSpec
,
TensorSpecialization
CSpec
,
index_t
NumGemmKPrefetchStage
,
ck
::
index_t
BlockSize
,
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
index_t
MPerBlock
,
ck
::
index_t
LPerBlock
,
index_t
NPerBlock
,
// Gemm0NPerBlock
ck
::
index_t
K0PerBlock
,
// K0 * K1 = Gemm0 GEMM_K Dim
index_t
KPerBlock
,
// Gemm0KPerBlock
ck
::
index_t
K1
,
//
index_t
Gemm1NPerBlock
,
ck
::
index_t
NPerBlock
,
index_t
Gemm1KPerBlock
,
ck
::
index_t
L0PerBlock
,
index_t
AK1
,
ck
::
index_t
L1
,
index_t
BK1
,
ck
::
index_t
MPerWMMA
,
index_t
B1K1
,
ck
::
index_t
LPerWMMA
,
index_t
MPerXDL
,
ck
::
index_t
NPerWMMA
,
index_t
NPerXDL
,
ck
::
index_t
MRepeat
,
index_t
MXdlPerWave
,
ck
::
index_t
LRepeat
,
index_t
NXdlPerWave
,
ck
::
index_t
NRepeat
,
index_t
Gemm1NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_
A
K1
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsExtraM
,
bool
ABlockLds
Add
ExtraM
,
typename
BBlockTransferThreadClusterLengths_
B
K0_
N_B
K1
,
typename
B
0
BlockTransferThreadClusterLengths_K0_
L_
K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
B
0
BlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
B
0
BlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
B
0
BlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
B
0
BlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_
B
K1
,
ck
::
index_t
B
0
BlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsExtra
N
,
bool
B
0
BlockLds
Add
Extra
L
,
typename
B1BlockTransferThreadClusterLengths_
BK
0_N_
BK
1
,
typename
B1BlockTransferThreadClusterLengths_
L
0_N_
L
1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
ck
::
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
ck
::
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_
BK
1
,
ck
::
index_t
B1BlockTransferDstScalarPerVector_
L
1
,
bool
B1BlockLdsExtraN
,
bool
B1BlockLds
Add
ExtraN
,
index_t
CShuffleM
XdlPerWave
PerShuffle
,
index_t
CShuffleM
Repeat
PerShuffle
,
index_t
CShuffleN
XdlPerWave
PerShuffle
,
index_t
CShuffleN
Repeat
PerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
ck
::
index_t
NumPrefetch
=
1
,
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDim
N
,
NumDim
L
,
NumDimK
,
NumDimK
,
NumDim
O
,
NumDim
N
,
ADataType
,
ADataType
,
BDataType
,
B
0
DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
B
0
ElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
MaskingSpec
>
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDim
N
>
0
&&
NumDimK
>
0
&&
NumDim
O
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDim
L
>
0
&&
NumDimK
>
0
&&
NumDim
N
>
0
,
"Number of dimension must be greater than 0"
);
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
...
@@ -210,64 +121,69 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -210,64 +121,69 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// TODO ANT: implement bias combination
// TODO ANT: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO ANT: use alias
static
constexpr
index_t
NumDimGemm0M
=
NumDimM
;
static
constexpr
index_t
NumDimGemm0M
=
NumDimM
;
static constexpr index_t NumDimGemm0N = NumDim
N
;
static
constexpr
index_t
NumDimGemm0N
=
NumDim
L
;
static
constexpr
index_t
NumDimGemm0K
=
NumDimK
;
static
constexpr
index_t
NumDimGemm0K
=
NumDimK
;
static
constexpr
index_t
NumDimGemm1M
=
NumDimM
;
static
constexpr
index_t
NumDimGemm1M
=
NumDimM
;
static constexpr index_t NumDimGemm1N = NumDimO;
static
constexpr
index_t
NumDimGemm1N
=
NumDimN
;
static constexpr index_t NumDimGemm1K = NumDimN;
static
constexpr
index_t
NumDimGemm1K
=
NumDimL
;
#endif
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
K1
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_
Xdl
_CShuffle
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_
Wmma
_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDim
N
,
NumDimK
,
NumDim
O
>
,
Sequence
<
NumDimG
,
NumDimM
,
NumDim
L
,
NumDimK
,
NumDim
N
>
,
Sequence
<
MPerBlock
,
N
PerBlock
,
KPerBlock
,
Gemm1
NPerBlock
>
,
Sequence
<
MPerBlock
,
L
PerBlock
,
KPerBlock
,
NPerBlock
>
,
GemmSpec
,
GemmSpec
,
ASpec
,
ASpec
,
BSpec
,
B
0
Spec
,
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
A
K1
>
{});
Number
<
K1
>
{});
}
}
static
auto
MakeBGridDescriptor_BK0_
N
_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_
n
s_ks_lengths_vec
,
static
auto
MakeB
0
GridDescriptor_BK0_
L
_BK1
(
const
std
::
vector
<
index_t
>&
b
0
_gs_
l
s_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_
n
s_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
b
0
_gs_
l
s_ks_strides_vec
)
{
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_
n
s_ks_lengths_vec
,
b_gs_
n
s_ks_strides_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b
0
_gs_
l
s_ks_lengths_vec
,
b
0
_gs_
l
s_ks_strides_vec
),
Number
<
B
K1
>
{});
Number
<
K1
>
{});
}
}
static
auto
static
auto
MakeB1GridDescriptor_B
K
0_N_B
K
1
(
const
std
::
vector
<
index_t
>&
b1_gs_
gemm1ns_gemm1k
s_lengths_vec
,
MakeB1GridDescriptor_B
L
0_N_B
L
1
(
const
std
::
vector
<
index_t
>&
b1_gs_
ns_l
s_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_
gemm1ns_gemm1k
s_strides_vec
)
const
std
::
vector
<
index_t
>&
b1_gs_
ns_l
s_strides_vec
)
{
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_
gemm1ns_gemm1k
s_lengths_vec
,
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_
ns_l
s_lengths_vec
,
b1_gs_
gemm1ns_gemm1k
s_strides_vec
),
b1_gs_
ns_l
s_strides_vec
),
Number
<
B1K
1
>
{});
Number
<
L
1
>
{});
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_
N
_BK1
=
decltype
(
MakeBGridDescriptor_BK0_
N
_BK1
({},
{}));
using
B
0
GridDesc_BK0_
L
_BK1
=
decltype
(
MakeB
0
GridDescriptor_BK0_
L
_BK1
({},
{}));
using
B1GridDesc_B
K
0_N_B
K
1
=
decltype
(
MakeB1GridDescriptor_B
K
0_N_B
K
1
({},
{}));
using
B1GridDesc_B
L
0_N_B
L
1
=
decltype
(
MakeB1GridDescriptor_B
L
0_N_B
L
1
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_
N
_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B
0
GridDesc_G_
L
_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_
K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_
L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
constexpr
static
auto
make_MaskOutPredicate
()
...
@@ -286,12 +202,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -286,12 +202,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct
ComputeBasePtrOfStridedBatch
struct
ComputeBasePtrOfStridedBatch
{
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_
N
_K
&
b_grid_desc_g_
n
_k
,
const
B
0
GridDesc_G_
L
_K
&
b
0
_grid_desc_g_
l
_k
,
const
B1GridDesc_G_N_
K
&
b1_grid_desc_g_n_
k
,
const
B1GridDesc_G_N_
L
&
b1_grid_desc_g_n_
l
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_
n
_k_
(
b_grid_desc_g_
n
_k
),
b
0
_grid_desc_g_
l
_k_
(
b
0
_grid_desc_g_
l
_k
),
b1_grid_desc_g_n_
k
_
(
b1_grid_desc_g_n_
k
),
b1_grid_desc_g_n_
l
_
(
b1_grid_desc_g_n_
l
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
)
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
)
{
{
}
}
...
@@ -301,14 +217,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -301,14 +217,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
a_grid_desc_g_m_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
a_grid_desc_g_m_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetBBasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetB
0
BasePtr
(
index_t
g_idx
)
const
{
{
return
b_grid_desc_g_
n
_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
b
0
_grid_desc_g_
l
_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetB1BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetB1BasePtr
(
index_t
g_idx
)
const
{
{
return
b1_grid_desc_g_n_
k
_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
b1_grid_desc_g_n_
l
_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
...
@@ -318,208 +234,202 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -318,208 +234,202 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
private:
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_
N
_K
b_grid_desc_g_
n
_k_
;
B
0
GridDesc_G_
L
_K
b
0
_grid_desc_g_
l
_k_
;
B1GridDesc_G_N_
K
b1_grid_desc_g_n_
k
_
;
B1GridDesc_G_N_
L
b1_grid_desc_g_n_
l
_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
};
};
// GridwiseGemm
// GridwiseOp
using
GridwiseGemm
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
using
GridwiseOp
=
GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
// DataType Family
GemmAccDataType
,
ADataType
,
B0DataType
,
Acc0DataType
,
B1DataType
,
Acc1DataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
// ElementwiseOp Family
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
B
0
ElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_
N
_BK1
,
B
0
GridDesc_BK0_
L
_BK1
,
B1GridDesc_B
K
0_N_B
K
1
,
B1GridDesc_B
L
0_N_B
L
1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
// Tiling Family
BlockSize
,
MPerBlock
,
MPerBlock
,
LPerBlock
,
K0PerBlock
,
// K0 * K1 = Gemm0 GEMM_K Dim
K1
,
//
NPerBlock
,
NPerBlock
,
KPerBlock
,
L0PerBlock
,
Gemm1NPerBlock
,
L1
,
Gemm1KPerBlock
,
MPerWMMA
,
AK1
,
LPerWMMA
,
BK1
,
NPerWMMA
,
B1K1
,
MRepeat
,
MPerXDL
,
LRepeat
,
NPerXDL
,
NRepeat
,
MXdlPerWave
,
// ThreadCluster Family
NXdlPerWave
,
BlockSize
,
Gemm1NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_
A
K1
,
ABlockTransferDstScalarPerVector_K1
,
true
,
true
,
ABlockLdsExtraM
,
ABlockLds
Add
ExtraM
,
BBlockTransferThreadClusterLengths_
B
K0_
N_B
K1
,
B
0
BlockTransferThreadClusterLengths_K0_
L_
K1
,
BBlockTransferThreadClusterArrangeOrder
,
B
0
BlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
B
0
BlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
B
0
BlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
B
0
BlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_
B
K1
,
B
0
BlockTransferDstScalarPerVector_K1
,
true
,
true
,
BBlockLdsExtra
N
,
B
0
BlockLds
Add
Extra
L
,
B1BlockTransferThreadClusterLengths_
BK
0_N_
BK
1
,
B1BlockTransferThreadClusterLengths_
L
0_N_
L
1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_
BK
1
,
B1BlockTransferDstScalarPerVector_
L
1
,
false
,
false
,
B1BlockLdsExtraN
,
B1BlockLds
Add
ExtraN
,
CShuffleM
XdlPerWave
PerShuffle
,
CShuffleM
Repeat
PerShuffle
,
CShuffleN
XdlPerWave
PerShuffle
,
CShuffleN
Repeat
PerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
// Argument
// FIXME: constness
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
Argument
(
const
ADataType
*
p_a_grid
,
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
B
0
DataType
*
p_b
0
_grid
,
const
B1DataType
*
p_b1_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_
n
s_ks_lengths
,
const
std
::
vector
<
index_t
>&
b
0
_gs_
l
s_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_
n
s_ks_strides
,
const
std
::
vector
<
index_t
>&
b
0
_gs_
l
s_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_
gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_
ns_ls_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_
gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
b1_gs_
ns_ls_strides
,
const
std
::
vector
<
index_t
>&
c_gs_ms_
gemm1
ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
c_gs_ms_
gemm1
ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_
n
s_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_
l
s_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_
n
s_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_
l
s_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_lengths
,
acc1_biases_gs_ms_gemm1ns_lengths
,
//
acc1_biases_gs_ms_
o
s_
lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_
n
s_
strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
index_t
M01
,
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
const
index_t
N01
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
B
0
ElementwiseOperation
b
0
_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b
0
_grid_
{
p_b
0
_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_
n
_bk1_
{
b
0
_grid_desc_bk0_
l
_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_
N
_BK1
(
b_gs_
n
s_ks_lengths
,
b_gs_
n
s_ks_strides
)},
DeviceOp
::
MakeB
0
GridDescriptor_BK0_
L
_BK1
(
b
0
_gs_
l
s_ks_lengths
,
b
0
_gs_
l
s_ks_strides
)},
b1_grid_desc_b
k
0_n_b
k
1_
{
DeviceOp
::
MakeB1GridDescriptor_BK0_N_BK1
(
b1_grid_desc_b
l
0_n_b
l
1_
{
b1_gs_gemm1ns_gemm1k
s_lengths
,
b1_gs_
gemm1ns_gemm1k
s_strides
)},
DeviceOp
::
MakeB1GridDescriptor_BL0_N_BL1
(
b1_gs_ns_l
s_lengths
,
b1_gs_
ns_l
s_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_m_n_
{
c_gs_ms_gemm1
ns_strides
)},
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_ns_lengths
,
c_gs_ms_
ns_strides
)},
a_grid_desc_g_m_k_
{
a_grid_desc_g_m_k_
{
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_g_
n
_k_
{
b
0
_grid_desc_g_
l
_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
b_gs_
n
s_ks_lengths
,
b_gs_
n
s_ks_strides
)},
Transform
::
MakeB0GridDescriptor_G_N_K
(
b
0
_gs_
l
s_ks_lengths
,
b
0
_gs_
l
s_ks_strides
)},
b1_grid_desc_g_n_
k
_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_grid_desc_g_n_
l
_
{
b1_gs_gemm1ns_gemm1k
s_lengths
,
b1_gs_
gemm1ns_gemm1k
s_strides
)},
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_ns_l
s_lengths
,
b1_gs_
ns_l
s_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
c_gs_ms_gemm1
ns_strides
)},
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_ns_lengths
,
c_gs_ms_
ns_strides
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
Gridwise
Gemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
block_2_ctile_map_
{
Gridwise
Op
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b
0
_element_op_
{
b
0
_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_
n
_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
b
0
_grid_desc_g_
l
_k_
.
GetLength
(
I1
)},
raw_lengths_mz_
n
z_kz_
gemm1
nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
raw_lengths_mz_
l
z_kz_nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_
n
s_ks_lengths
[
NumDimG
+
NumDim
N
-
1
],
b
0
_gs_
l
s_ks_lengths
[
NumDimG
+
NumDim
L
-
1
],
b_gs_
n
s_ks_lengths
[
NumDimG
+
NumDim
N
+
NumDimK
-
1
],
b
0
_gs_
l
s_ks_lengths
[
NumDimG
+
NumDim
L
+
NumDimK
-
1
],
b1_gs_
gemm1ns_gemm1k
s_lengths
[
NumDimG
+
NumDim
O
-
1
]},
b1_gs_
ns_l
s_lengths
[
NumDimG
+
NumDim
N
-
1
]},
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]},
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]},
b
_n
z_kz_strides_
{
b_gs_
n
s_ks_strides
[
NumDimG
+
NumDim
N
-
1
],
b
0_l
z_kz_strides_
{
b
0
_gs_
l
s_ks_strides
[
NumDimG
+
NumDim
L
-
1
],
b
_gs_
n
s_ks_strides
[
NumDimG
+
NumDim
N
+
NumDimK
-
1
]},
b0
_gs_
l
s_ks_strides
[
NumDimG
+
NumDim
L
+
NumDimK
-
1
]},
b1_nz_
k
z_strides_
{
b1_gs_
gemm1ns_gemm1k
s_strides
[
NumDimG
+
NumDim
O
-
1
],
b1_nz_
l
z_strides_
{
b1_gs_
ns_l
s_strides
[
NumDimG
+
NumDim
N
-
1
],
b1_gs_
gemm1ns_gemm1k
s_strides
[
NumDimG
+
NumDim
O
+
NumDim
N
-
1
]},
b1_gs_
ns_l
s_strides
[
NumDimG
+
NumDim
N
+
NumDim
L
-
1
]},
c_mz_
gemm1
nz_strides_
{
c_gs_ms_
gemm1
ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_mz_nz_strides_
{
c_gs_ms_ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_
gemm1
ns_strides
[
NumDimG
+
NumDimM
+
NumDim
O
-
1
]},
c_gs_ms_ns_strides
[
NumDimG
+
NumDimM
+
NumDim
N
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_
base_ptr
_of_batch_
{
compute_
ptr_offset
_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_
n
_k_
,
b1_grid_desc_g_n_
k
_
,
c_grid_desc_g_m_n_
}
a_grid_desc_g_m_k_
,
b
0
_grid_desc_g_
l
_k_
,
b1_grid_desc_g_n_
l
_
,
c_grid_desc_g_m_n_
}
{
{
// TODO ANT: implement bias addition
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_
n
s_lengths
;
ignore
=
acc0_biases_gs_ms_
l
s_lengths
;
ignore
=
acc0_biases_gs_ms_
n
s_strides
;
ignore
=
acc0_biases_gs_ms_
l
s_strides
;
ignore
=
acc1_biases_gs_ms_
gemm1
ns_lengths
;
ignore
=
acc1_biases_gs_ms_ns_lengths
;
ignore
=
acc1_biases_gs_ms_
gemm1
ns_strides
;
ignore
=
acc1_biases_gs_ms_ns_strides
;
if
(
Gridwise
Gemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
Gridwise
Op
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_
n
_bk1_
,
b
0
_grid_desc_bk0_
l
_bk1_
,
b1_grid_desc_b
k
0_n_b
k
1_
,
b1_grid_desc_b
l
0_n_b
l
1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
block_2_ctile_map_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
Gridwise
Gemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
Gridwise
Op
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
}
}
}
}
void
Print
()
const
// Pointers
{
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"c_grid_desc_g_m_n_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
}
// pointers
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
B
0
DataType
*
p_b
0
_grid_
;
const
B1DataType
*
p_b1_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
//
t
ensor
d
escriptor
//
T
ensor
D
escriptor
s
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_
N
_BK1
b_grid_desc_bk0_
n
_bk1_
;
B
0
GridDesc_BK0_
L
_BK1
b
0
_grid_desc_bk0_
l
_bk1_
;
B1GridDesc_B
K
0_N_B
K
1
b1_grid_desc_b
k
0_n_b
k
1_
;
B1GridDesc_B
L
0_N_B
L
1
b1_grid_desc_b
l
0_n_b
l
1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_
N
_K
b_grid_desc_g_
n
_k_
;
B
0
GridDesc_G_
L
_K
b
0
_grid_desc_g_
l
_k_
;
B1GridDesc_G_N_
K
b1_grid_desc_g_n_
k
_
;
B1GridDesc_G_N_
L
b1_grid_desc_g_n_
l
_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
//
b
lock
-
to
-c-t
ile map
//
B
lock
to
T
ile map
ping
typename
Gridwise
Gemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
Gridwise
Op
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
//
e
lement
-
wise
o
p
//
E
lementwise
O
p
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
B
0
ElementwiseOperation
b
0
_element_op_
;
AccElementwiseOperation
acc_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
...
@@ -527,15 +437,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -527,15 +437,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// check C0 masking and padding
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
C0MatrixMask
c0_matrix_mask_
;
// For robust IsSupportedArgument() check
// Strides for the last M/N/K dimensions of A/B0/B1/C
std
::
vector
<
index_t
>
raw_lengths_mz_nz_kz_gemm1nz_
;
// for sanity check of vector load/store
std
::
vector
<
index_t
>
raw_lengths_mz_lz_kz_nz_
;
std
::
vector
<
index_t
>
a_mz_kz_strides_
;
std
::
vector
<
index_t
>
a_mz_kz_strides_
;
std
::
vector
<
index_t
>
b
_n
z_kz_strides_
;
std
::
vector
<
index_t
>
b
0_l
z_kz_strides_
;
std
::
vector
<
index_t
>
b1_nz_
k
z_strides_
;
std
::
vector
<
index_t
>
b1_nz_
l
z_strides_
;
std
::
vector
<
index_t
>
c_mz_
gemm1
nz_strides_
;
std
::
vector
<
index_t
>
c_mz_nz_strides_
;
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// Batch Offset
ComputeBasePtrOfStridedBatch
compute_ptr_offset_of_batch_
;
};
};
// Invoker
// Invoker
...
@@ -545,38 +457,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -545,38 +457,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
!
DeviceOp
::
IsSupportedArgument
(
arg
))
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
{
throw
std
::
runtime_error
(
"wrong! unsupported argument"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
// Gemm0_K
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
<
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1
<
GridwiseOp
,
GridwiseGemm
,
ADataType
,
ADataType
,
// TODO: distiguish A/B datatype
B0DataType
,
B1DataType
,
CDataType
,
CDataType
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
B0GridDesc_BK0_L_BK1
,
DeviceOp
::
B1GridDesc_BL0_N_BL1
,
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
B
0
ElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
typename
GridwiseOp
::
DefaultBlock2CTileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -584,36 +490,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -584,36 +490,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b
0
_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b0_grid_desc_bk0_l_bk1_
,
arg
.
b1_grid_desc_bl0_n_bl1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b
0
_element_op_
,
arg
.
acc_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
c0_matrix_mask_
);
arg
.
c0_matrix_mask_
,
arg
.
block_2_ctile_map_
);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
// to concern Gemm0's loop
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
}
// polymorphic
// polymorphic
...
@@ -632,25 +534,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -632,25 +534,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
#if DEBUG_LOG
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
arg
.
Print
();
{
#endif
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
return
false
;
}
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
constexpr
(
!
(
is_same_v
<
Acc1DataType
,
float
>
||
is_same_v
<
Acc1DataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
{
return
false
;
return
false
;
}
}
// TODO ANT: Check if tensor specialization & strides mismatch
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b0_grid_desc_bk0_l_bk1_
,
arg
.
b1_grid_desc_bl0_n_bl1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
return
false
;
}
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_
gemm1
n
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
c_n
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_
gemm1
n
=
arg
.
b1_grid_desc_b
k
0_n_b
k
1_
.
GetLength
(
I1
);
const
index_t
b1_n
=
arg
.
b1_grid_desc_b
l
0_n_b
l
1_
.
GetLength
(
I1
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_
gemm1
n
==
b1_
gemm1
n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_n
==
b1_n
))
{
{
return
false
;
return
false
;
}
}
...
@@ -658,19 +575,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -658,19 +575,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const
auto
MzRaw
=
arg
.
raw_lengths_mz_
n
z_kz_
gemm1
nz_
[
0
];
const
auto
MzRaw
=
arg
.
raw_lengths_mz_
l
z_kz_nz_
[
0
];
const
auto
N
zRaw
=
arg
.
raw_lengths_mz_
n
z_kz_
gemm1
nz_
[
1
];
const
auto
L
zRaw
=
arg
.
raw_lengths_mz_
l
z_kz_nz_
[
1
];
const
auto
KzRaw
=
arg
.
raw_lengths_mz_
n
z_kz_
gemm1
nz_
[
2
];
const
auto
KzRaw
=
arg
.
raw_lengths_mz_
l
z_kz_nz_
[
2
];
const
auto
Gemm1
NzRaw
=
arg
.
raw_lengths_mz_
n
z_kz_
gemm1
nz_
[
3
];
const
auto
NzRaw
=
arg
.
raw_lengths_mz_
l
z_kz_nz_
[
3
];
// Check scalar per vector requirement
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
b_extent_lowest
=
BBlockTransferSrcVectorDim
==
2
?
KzRaw
:
N
zRaw
;
const
auto
b
0
_extent_lowest
=
B
0
BlockTransferSrcVectorDim
==
2
?
KzRaw
:
L
zRaw
;
const
auto
b1_extent_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
N
zRaw
:
Gemm1
NzRaw
;
const
auto
b1_extent_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
L
zRaw
:
NzRaw
;
const
auto
c_extent_lowest
=
Gemm1
NzRaw
;
const
auto
c_extent_lowest
=
NzRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b
0
_extent_lowest
%
B
0
BlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
{
...
@@ -680,24 +597,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -680,24 +597,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
-       const auto b_stride_lowest =
-           BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
+       const auto b0_stride_lowest =
+           B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
        const auto b1_stride_lowest =
-           B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
+           B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
        const auto c_stride_lowest =
-           arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
+           arg.c_mz_nz_strides_[1];

-       if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+       if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            return false;
        }

-       return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
-                                          arg.b_grid_desc_bk0_n_bk1_,
-                                          arg.b1_grid_desc_bk0_n_bk1_,
-                                          arg.c_grid_desc_m_n_,
-                                          arg.block_2_ctile_map_);
+       return true;
    }
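Reviewer note: the scalar-per-vector and stride checks above encode a simple rule: a vectorized global load of width V along a dimension is only safe if that dimension is contiguous (stride 1) and its raw extent is a multiple of V. A minimal standalone sketch of that rule, with hypothetical extents and vector widths (not taken from this device's argument struct):

#include <cstdio>

// Hypothetical check mirroring the idea above: a vector access of width
// `scalar_per_vector` along the lowest dimension is valid only if the extent
// divides evenly and the lowest-dim stride is 1.
bool is_vector_access_ok(int extent_lowest, int stride_lowest, int scalar_per_vector)
{
    return (extent_lowest % scalar_per_vector == 0) && (stride_lowest == 1);
}

int main()
{
    std::printf("%d\n", is_vector_access_ok(128, 1, 8)); // 1: 128 % 8 == 0 and contiguous
    std::printf("%d\n", is_vector_access_ok(100, 1, 8)); // 0: the tail vector would read out of bounds
    return 0;
}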
    // polymorphic
...
@@ -706,114 +619,115 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(
        const ADataType* p_a,
-       const BDataType* p_b,
+       const B0DataType* p_b0,
        const B1DataType* p_b1,
        CDataType* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
-       const std::vector<index_t>& b_gs_ns_ks_lengths,
-       const std::vector<index_t>& b_gs_ns_ks_strides,
-       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-       const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
-       const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
-       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
-       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
-       const std::array<std::vector<ck::index_t>, NumAcc1Bias>
-           acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
-       const std::array<std::vector<ck::index_t>, NumAcc1Bias>
-           acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+       const std::vector<index_t>& b0_gs_ls_ks_lengths,
+       const std::vector<index_t>& b0_gs_ls_ks_strides,
+       const std::vector<index_t>& b1_gs_ns_ls_lengths,
+       const std::vector<index_t>& b1_gs_ns_ls_strides,
+       const std::vector<index_t>& c_gs_ms_ns_lengths,
+       const std::vector<index_t>& c_gs_ms_ns_strides,
+       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
+       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
+       const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
+       const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
-       BElementwiseOperation b_element_op,
+       B0ElementwiseOperation b0_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
-                       p_b,
+                       p_b0,
                        p_b1,
                        p_c,
                        p_acc0_biases,
                        p_acc1_biases,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
-                       b_gs_ns_ks_lengths,
-                       b_gs_ns_ks_strides,
-                       b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-                       b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-                       c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
-                       c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
-                       acc0_biases_gs_ms_ns_lengths,
-                       acc0_biases_gs_ms_ns_strides,
-                       acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
-                       acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+                       b0_gs_ls_ks_lengths,
+                       b0_gs_ls_ks_strides,
+                       b1_gs_ns_ls_lengths,
+                       b1_gs_ns_ls_strides,
+                       c_gs_ms_ns_lengths,
+                       c_gs_ms_ns_strides,
+                       acc0_biases_gs_ms_ls_lengths,
+                       acc0_biases_gs_ms_ls_strides,
+                       acc1_biases_gs_ms_ns_lengths,
+                       acc1_biases_gs_ms_ns_strides,
+                       1,
+                       1,
                        a_element_op,
-                       b_element_op,
+                       b0_element_op,
                        acc_element_op,
                        b1_element_op,
                        c_element_op};
    }
-   static auto MakeInvoker() { return Invoker{}; }
-
    // polymorphic
-   // FIXME: constness
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
-       const void* p_b,
+       const void* p_b0,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
-       const std::vector<index_t>& b_gs_ns_ks_lengths,
-       const std::vector<index_t>& b_gs_ns_ks_strides,
-       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-       const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
-       const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
-       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
-       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
-       const std::array<std::vector<ck::index_t>, NumAcc1Bias>
-           acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
-       const std::array<std::vector<ck::index_t>, NumAcc1Bias>
-           acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+       const std::vector<index_t>& b0_gs_ls_ks_lengths,
+       const std::vector<index_t>& b0_gs_ls_ks_strides,
+       const std::vector<index_t>& b1_gs_ns_ls_lengths,
+       const std::vector<index_t>& b1_gs_ns_ls_strides,
+       const std::vector<index_t>& c_gs_ms_ns_lengths,
+       const std::vector<index_t>& c_gs_ms_ns_strides,
+       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
+       const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
+       const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
+       const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
-       BElementwiseOperation b_element_op,
+       B0ElementwiseOperation b0_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                         static_cast<const BDataType*>(p_b),
+                                         static_cast<const B0DataType*>(p_b0),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
-                                         p_acc0_biases, // cast in struct Argument
-                                         p_acc1_biases, // cast in struct Argument
+                                         p_acc0_biases,
+                                         p_acc1_biases,
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
-                                         b_gs_ns_ks_lengths,
-                                         b_gs_ns_ks_strides,
-                                         b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-                                         b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-                                         c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
-                                         c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
-                                         acc0_biases_gs_ms_ns_lengths,
-                                         acc0_biases_gs_ms_ns_strides,
-                                         acc1_biases_gs_ms_gemm1ns_lengths,
-                                         acc1_biases_gs_ms_gemm1ns_strides,
+                                         b0_gs_ls_ks_lengths,
+                                         b0_gs_ls_ks_strides,
+                                         b1_gs_ns_ls_lengths,
+                                         b1_gs_ns_ls_strides,
+                                         c_gs_ms_ns_lengths,
+                                         c_gs_ms_ns_strides,
+                                         acc0_biases_gs_ms_ls_lengths,
+                                         acc0_biases_gs_ms_ls_strides,
+                                         acc1_biases_gs_ms_ns_lengths,
+                                         acc1_biases_gs_ms_ns_strides,
+                                         1,
+                                         1,
                                          a_element_op,
-                                         b_element_op,
+                                         b0_element_op,
                                          acc_element_op,
                                          b1_element_op,
                                          c_element_op);
    }

+   static auto MakeInvoker() { return Invoker{}; }
+
    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
...
@@ -825,25 +739,33 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
    {
        auto str = std::stringstream();

+       std::map<LoopScheduler, std::string> LoopSchedToString{
+           {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
+
+       std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
+                                                                      {PipelineVersion::v2, "v2"}};

        // clang-format off
-       str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
-           << "<"
-           << BlockSize << ", "
-           << MPerBlock << ", "
-           << NPerBlock << ", "
-           << KPerBlock << ", "
-           << AK1 << ", "
-           << BK1 << ", "
-           << MPerBlock << ", "
-           << Gemm1NPerBlock << ", "
-           << Gemm1KPerBlock << ", "
-           << B1K1 << ", "
-           << getGemmSpecializationString(GemmSpec) << ", "
-           << "ASpec" << getTensorSpecializationString(ASpec) << ", "
-           << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
-           << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
-           << "CSpec" << getTensorSpecializationString(CSpec) << ", "
-           << getMaskingSpecializationString(MaskingSpec) << ">";
+       str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
+           << "<"
+           << BlockSize << ", "
+           << MPerBlock << ", "
+           << LPerBlock << ", "
+           << K0PerBlock << ", "
+           << K1 << ", "
+           << MPerBlock << ", "
+           << NPerWMMA << ", "
+           << MPerBlock << ", "
+           << NPerBlock << ", "
+           << L0PerBlock << ", "
+           << L1
+           << ">"
+           << " NumPrefetch: "
+           << NumPrefetch << ", "
+           << "LoopScheduler: "
+           << LoopSchedToString[LoopSched] << ", "
+           << "PipelineVersion: "
+           << PipelineVersionToString[PipelineVer];
        // clang-format on
        return str.str();
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...
@@ -20,71 +20,106 @@ namespace ck {
template <typename GridwiseGemm,
          typename FloatA,
-         typename FloatB,
+         typename FloatB0,
+         typename FloatB1,
          typename FloatC,
-         typename AGridDesc_K0_M_K1,
-         typename BGridDesc_K0_N_K1,
+         typename AGridDesc_AK0_M_AK1,
+         typename B0GridDesc_BK0_L_BK1,
+         typename B1GridDesc_BL0_N_BL1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
-         typename BElementwiseOperation,
+         typename B0ElementwiseOperation,
+         typename AccElementwiseOperation,
+         typename B1ElementwiseOperation,
          typename CElementwiseOperation,
+         typename ComputeBasePtrOfStridedBatch,
+         typename C0MatrixMask,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-       kernel_gemm_wmma(
+       kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
            const FloatA* __restrict__ p_a_grid,
-           const FloatB* __restrict__ p_b0_grid,
+           const FloatB0* __restrict__ p_b0_grid,
+           const FloatB1* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
-           const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-           const BGridDesc_K0_N_K1 b0_grid_desc_k0_l_k1,
+           const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+           const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
+           const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
-           // const
-           // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
-           // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
            const AElementwiseOperation a_element_op,
-           const BElementwiseOperation b_element_op,
+           const B0ElementwiseOperation b0_element_op,
+           const AccElementwiseOperation acc_element_op,
+           const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
+           const index_t batch_count,
+           const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+           const C0MatrixMask c0_matrix_mask,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

+   const index_t num_blocks_per_batch =
+       __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+   const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+
+   const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+       static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
+   const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
+       static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
+   const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
+       static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+   const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+       static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+
-   GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
-                                                 p_b0_grid,
-                                                 p_c_grid,
+   GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                 p_b0_grid + b0_batch_offset,
+                                                 p_b1_grid + b1_batch_offset,
+                                                 p_c_grid + c_batch_offset,
                                                  p_shared,
-                                                 a_grid_desc_k0_m_k1,
-                                                 b0_grid_desc_k0_l_k1,
+                                                 a_grid_desc_ak0_m_ak1,
+                                                 b0_grid_desc_bk0_l_bk1,
+                                                 b1_grid_desc_l0_n_l1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
-                                                 b_element_op,
+                                                 b0_element_op,
+                                                 acc_element_op,
+                                                 b1_element_op,
                                                  c_element_op,
+                                                 c0_matrix_mask,
                                                  block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b0_grid;
+   ignore = p_b1_grid;
    ignore = p_c_grid;
-   ignore = a_grid_desc_k0_m_k1;
-   ignore = b0_grid_desc_k0_l_k1;
+   ignore = a_grid_desc_ak0_m_ak1;
+   ignore = b0_grid_desc_bk0_l_bk1;
+   ignore = b1_grid_desc_l0_n_l1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
-   ignore = b_element_op;
+   ignore = b0_element_op;
+   ignore = acc_element_op;
+   ignore = b1_element_op;
    ignore = c_element_op;
+   ignore = batch_count;
+   ignore = compute_base_ptr_of_batch;
+   ignore = c0_matrix_mask;
    ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
}
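Reviewer note: each thread block derives its batch index from its one-dimensional block id, then offsets every grid pointer by that batch's base pointer. A stripped-down host-side sketch of the same indexing arithmetic (names and shapes here are illustrative, not the kernel's):

#include <cassert>

// Illustrative only: map a flat block id onto (batch index, block-within-batch)
// the same way the kernel does with get_block_1d_id() / batch_count.
struct BlockMapping
{
    int grid_size;   // total number of thread blocks launched
    int batch_count; // G dimension

    int blocks_per_batch() const { return grid_size / batch_count; }
    int batch_of(int block_id) const { return block_id / blocks_per_batch(); }
    int local_block_of(int block_id) const { return block_id % blocks_per_batch(); }
};

int main()
{
    BlockMapping map{/*grid_size=*/96, /*batch_count=*/4}; // 24 C tiles per batch
    assert(map.blocks_per_batch() == 24);
    assert(map.batch_of(50) == 2);       // block 50 works on batch 2
    assert(map.local_block_of(50) == 2); // and is the third tile of that batch
    return 0;
}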
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
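Reviewer note: the math this gridwise kernel implements is a row-softmax sandwiched between the two GEMMs above. A naive single-threaded sketch (plain float, row-major, no scale factor or masking) that can serve as a small-shape correctness oracle; it is not part of the library:

#include <cmath>
#include <vector>

// Naive reference: C[M x N] = Softmax_row(A[M x K] * B0[K x L]) * B1[L x N].
void reference_gemm_softmax_gemm(const std::vector<float>& a,
                                 const std::vector<float>& b0,
                                 const std::vector<float>& b1,
                                 std::vector<float>& c,
                                 int M, int K, int L, int N)
{
    std::vector<float> acc(M * L);
    for(int m = 0; m < M; ++m)
    {
        // Gemm0 row: acc[m, :] = A[m, :] * B0
        for(int l = 0; l < L; ++l)
        {
            float s = 0.f;
            for(int k = 0; k < K; ++k)
                s += a[m * K + k] * b0[k * L + l];
            acc[m * L + l] = s;
        }
        // numerically stable row softmax
        float row_max = acc[m * L];
        for(int l = 1; l < L; ++l)
            row_max = std::fmax(row_max, acc[m * L + l]);
        float row_sum = 0.f;
        for(int l = 0; l < L; ++l)
        {
            acc[m * L + l] = std::exp(acc[m * L + l] - row_max);
            row_sum += acc[m * L + l];
        }
        for(int l = 0; l < L; ++l)
            acc[m * L + l] /= row_sum;
        // Gemm1 row: C[m, :] = acc[m, :] * B1
        for(int n = 0; n < N; ++n)
        {
            float s = 0.f;
            for(int l = 0; l < L; ++l)
                s += acc[m * L + l] * b1[l * N + n];
            c[m * N + n] = s;
        }
    }
}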
-template <index_t BlockSize,
-          typename FloatA,
+template <typename FloatA,
          typename FloatB0,
+         typename FloatAcc0,
          typename FloatB1,
-         typename FloatAcc,
+         typename FloatAcc1,
          typename FloatCShuffle,
          typename FloatC,
          typename AElementwiseOperation,
...
@@ -93,26 +128,24 @@ template <index_t BlockSize,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename AGridDesc_K0_M_K1,
-         typename B0GridDesc_K0_L_K1,
-         typename B1GridDesc_L0_N_L1,
+         typename AGridDesc_AK0_M_AK1,
+         typename B0GridDesc_BK0_L_BK1,
+         typename B1GridDesc_BL0_N_BL1,
          typename CGridDesc_M_N,
-         index_t Gemm0MPerBlock,
-         index_t Gemm0LPerBlock,
-         index_t Gemm0K0PerBlock,
-         index_t Gemm0K1Value,
-         index_t Gemm0MPerWmma,
-         index_t Gemm0LPerWmma,
-         index_t Gemm0MRepeat,
-         index_t Gemm0LRepeat,
-         index_t Gemm1MPerBlock,
-         index_t Gemm1NPerBlock,
-         index_t Gemm1L0PerBlock,
-         index_t Gemm1L1Value,
-         index_t Gemm1MPerWmma,
-         index_t Gemm1NPerWmma,
-         index_t Gemm1MRepeat,
-         index_t Gemm1NRepeat,
+         index_t MPerBlock,
+         index_t LPerBlock,
+         index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
+         index_t K1Value,
+         index_t NPerBlock,
+         index_t L0PerBlock,
+         index_t L1Value,
+         index_t MPerWmma,
+         index_t LPerWmma,
+         index_t NPerWmma,
+         index_t MRepeat,
+         index_t LRepeat,
+         index_t NRepeat,
+         index_t BlockSize,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
...
@@ -141,6 +174,8 @@ template <index_t BlockSize,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+         bool PadN,
+         bool MaskOutUpperTriangle,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
...
@@ -155,57 +190,44 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

-   // K1 should be Number<...>
-   static constexpr auto K1 = Number<Gemm0K1Value>{};
-   static constexpr auto N1 = Number<Gemm1N1Value>{};
+   // K1Value should be Number<...>
+   static constexpr auto AK0 = Number<K0PerBlock>{};
+   static constexpr auto AK1 = Number<K1Value>{};
+   static constexpr auto BK0 = Number<K0PerBlock>{};
+   static constexpr auto BK1 = Number<K1Value>{};
+   static constexpr auto L0  = Number<L0PerBlock>{};
+   static constexpr auto L1  = Number<L1Value>{};
+
+   static constexpr auto Gemm0MWaves = MPerBlock / (MPerWmma * MRepeat);
+   static constexpr auto Gemm0LWaves = L0PerBlock * L1Value / (LPerWmma * LRepeat);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

-   __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
+   __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
-       constexpr auto max_lds_align = K1;
-
        // A matrix in LDS memory, dst of blockwise copy
-       constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-           if constexpr(ABlockLdsExtraM)
-           {
-               return make_naive_tensor_descriptor(
-                   make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                   make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-           }
-           else
-           {
-               return make_naive_tensor_descriptor_aligned(
-                   make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-           }
-       }();
-
-       return a_block_desc_k0perblock_mperblock_k1;
+       return make_naive_tensor_descriptor(
+           make_tuple(AK0, Number<MPerBlock>{}, AK1),
+           make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }

-   __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
+   __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
    {
-       constexpr auto max_lds_align = K1;
-
        // B matrix in LDS memory, dst of blockwise copy
-       constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-           if constexpr(BBlockLdsExtraN)
-           {
-               return make_naive_tensor_descriptor(
-                   make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                   make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-           }
-           else
-           {
-               return make_naive_tensor_descriptor_aligned(
-                   make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-           }
-       }();
-
-       return b_block_desc_k0perblock_nperblock_k1;
+       return make_naive_tensor_descriptor(
+           make_tuple(BK0, Number<LPerBlock>{}, BK1),
+           make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
    }

+   __host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
+   {
+       // B1 matrix in LDS memory, dst of blockwise copy
+       return make_naive_tensor_descriptor(
+           make_tuple(L0, Number<NPerBlock>{}, L1),
+           make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * L1, L1, I1));
+   }
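Reviewer note: the `+ ABlockLdsExtraM` / `+ B0BlockLdsExtraN` / `+ B1BlockLdsExtraN` term pads the leading stride of the LDS tile so that consecutive rows do not land on the same LDS bank. A small standalone sketch of the padded-stride layout (row, column and pad counts are invented for illustration):

#include <cstdio>

// Toy padded row-major layout: the logical tile is rows x cols, but each row is
// stored with `cols + pad` elements so column-wise accesses spread across banks.
struct PaddedTile
{
    int rows, cols, pad;
    int element_space() const { return rows * (cols + pad); }
    int offset(int r, int c) const { return r * (cols + pad) + c; }
};

int main()
{
    PaddedTile tile{/*rows=*/16, /*cols=*/64, /*pad=*/1};
    std::printf("allocated elements: %d\n", tile.element_space()); // 16 * 65 = 1040
    std::printf("offset(1, 0) = %d\n", tile.offset(1, 0));         // 65, not 64
    return 0;
}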
    __host__ __device__ static constexpr auto
...
@@ -228,55 +250,68 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
-       constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-           GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-       constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-           GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-       constexpr auto max_lds_align = K1;
-
-       constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-       constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-           b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-       return (a_block_space_size_aligned * sizeof(FloatA) +
-               b_block_space_size_aligned * sizeof(FloatB));
+       const index_t gemm0_bytes_end =
+           (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
+            SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
+
+       const index_t gemm1_bytes_end =
+           (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
+           sizeof(FloatB1);
+
+       const index_t softmax_bytes_end =
+           (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
+           sizeof(FloatAcc0);
+
+       const index_t c_block_bytes_end =
+           SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+
+       return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
    }
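Reviewer note: because the gemm0 stage (A and B0 tiles), the gemm1 stage (B1 tile), the softmax reduction workspace and the C-shuffle staging buffer occupy LDS at different times, the block only needs the largest of the four byte counts, not their sum. A tiny sketch with made-up per-stage sizes:

#include <algorithm>
#include <cstdio>

int main()
{
    // Hypothetical per-stage LDS footprints in bytes; the real values come from
    // the block descriptors and the SharedMemTrait offsets above.
    const int gemm0_bytes_end   = 16 * 1024; // A tile + B0 tile
    const int gemm1_bytes_end   = 12 * 1024; // B1 tile (A1 lives in registers)
    const int softmax_bytes_end = 1 * 1024;  // cross-wave reduction workspace
    const int c_shuffle_bytes   = 8 * 1024;  // C shuffle staging

    const int lds_bytes =
        std::max({gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_shuffle_bytes});
    std::printf("LDS needed: %d bytes (max of the stages, not the sum)\n", lds_bytes);
    return 0;
}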
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
-   CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
-                 const BGridDesc_K0_N_K1& b0_grid_desc_k0_l_k1,
+   CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+                 const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
+                 const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    {
-       static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
-                     "wrong! K1 need to be known at compile-time");
-
        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
-                         (NPerBlock % (NRepeat * NPerWmma)) == 0,
+                         (LPerBlock % (LPerWmma * LRepeat)) == 0,
                      "Invalid tuning param!");

-       const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
-       const auto N  = b0_grid_desc_k0_l_k1.GetLength(I1);
-       const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-
-       if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
-            K0 == b0_grid_desc_k0_l_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
-            K1 == b0_grid_desc_k0_l_k1.GetLength(I2)))
-       {
-           return false;
-       }
-
-       if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
-       {
-           return false;
-       }
-
-       // check gridwise gemm pipeline
-       const auto num_k_loop = K0 / K0PerBlock;
-
-       if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+       const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
+       const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
+       const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
+       const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
+
+       const auto KPerBlock = K0PerBlock * K1Value;
+
+       if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
+       {
+           return false;
+       }
+
+       if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
+       {
+           return false;
+       }
+
+       // check gemm0 gridwise gemm pipeline
+       const auto num_gemm0_k_loop = K / KPerBlock;
+       if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
+       {
+           return false;
+       }
+
+       // check gemm1 gridwise gemm pipeline
+       if(!(LPerBlock % (L0PerBlock * L1Value) == 0))
+       {
+           return false;
+       }
+
+       const auto num_gemm1_k_inner_loop = LPerBlock / (L0PerBlock * L1Value);
+       if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        {
            return false;
        }
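Reviewer note: CheckValidity rejects problem sizes that do not tile evenly, so every GEMM extent must be a multiple of the corresponding per-block tile. A self-contained sketch of that shape gate with illustrative tile sizes (not this kernel's actual tuning parameters):

#include <cstdio>

// Illustrative tile-divisibility gate: M/L/K/N must split into whole block tiles.
bool shapes_tile_evenly(int M, int L, int K, int N,
                        int MPerBlock, int LPerBlock, int KPerBlock, int NPerBlock)
{
    return M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0;
}

int main()
{
    // e.g. attention-like shape: M=512 queries, L=512 keys, K=64, N=64 head dim
    std::printf("%d\n", shapes_tile_evenly(512, 512, 64, 64, 128, 128, 32, 64)); // 1
    std::printf("%d\n", shapes_tile_evenly(500, 512, 64, 64, 128, 128, 32, 64)); // 0: 500 % 128 != 0
    return 0;
}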
...
@@ -292,7 +327,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
-       const index_t num_loop = K / (K0PerBlock * K1);
+       const index_t num_loop = K / (K0PerBlock * K1Value);

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
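Reviewer note: for example, with K = 256, K0PerBlock = 8 and K1Value = 8 the block walks the K dimension 256 / (8 * 8) = 4 times, and the pipeline's main-loop variant is selected when that count is large enough. A one-line arithmetic check with those illustrative numbers:

#include <cassert>

int main()
{
    // Illustrative: K is consumed K0PerBlock * K1Value elements per main-loop iteration.
    constexpr int K = 256, K0PerBlock = 8, K1Value = 8;
    constexpr int num_loop = K / (K0PerBlock * K1Value);
    static_assert(num_loop == 4, "four K-block iterations for this shape");
    assert(num_loop > 1); // i.e. the pipeline does have a main K-block loop
    return 0;
}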
...
@@ -328,6 +363,42 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

+   struct SharedMemTrait
+   {
+       // LDS allocation for A and B: be careful of alignment
+       static constexpr auto a_block_desc_ak0_m_ak1 =
+           GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+       static constexpr auto b0_block_desc_bk0_l_bk1 =
+           GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
+       static constexpr auto b1_block_desc_bl0_n_bl1 =
+           GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
+
+       static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), L1);
+
+       static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+           a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+       static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
+           b0_block_desc_bk0_l_bk1.GetElementSpaceSize(), max_lds_align);
+       static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
+           b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);
+
+       static constexpr auto a_block_space_offset  = 0;
+       static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
+       static constexpr auto b1_block_space_offset = 0;
+
+       // LDS allocation for reduction
+       static constexpr index_t reduction_space_size_aligned =
+           math::integer_least_multiple(BlockSize, max_lds_align);
+       static constexpr auto reduction_space_offset = 0;
+
+       // LDS allocation for C shuffle in LDS
+       static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
+           GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
+       static constexpr auto c_block_space_size =
+           c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize();
+   };
<
bool
HasMainKBlockLoop
,
typename
C0MatrixMask
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
C0MatrixMask
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
...
@@ -335,9 +406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -335,9 +406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const
FloatB1
*
__restrict__
p_b1_grid
,
const
FloatB1
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_
A
K0_M_
A
K1
&
a_grid_desc_k0_m_k1
,
const
B0GridDesc_K0_L_K1
&
b0_grid_desc_k0_l_k1
,
const
B0GridDesc_
B
K0_L_
B
K1
&
b0_grid_desc_k0_l_k1
,
const
B1GridDesc_L0_N_L1
&
b1_grid_desc_l0_n_l1
,
const
B1GridDesc_
B
L0_N_
B
L1
&
b1_grid_desc_l0_n_l1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
...
@@ -380,9 +451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -380,9 +451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-       constexpr auto max_lds_align = K1;
+       // constexpr auto max_lds_align = K1Value;

-       constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-           GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-       constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-           GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+       constexpr auto a_block_desc_k0perblock_mperblock_k1 =
+           GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+       constexpr auto b0_block_desc_k0perblock_lperblock_k1 =
+           GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
...
@@ -390,7 +461,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            /* typename SrcElementwiseOperation, */ AElementwiseOperation,
            /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
            /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
-           /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
+           /* typename BlockSliceLengths, */ Sequence<AK0, MPerBlock, AK1>,
            /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
            /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
            /* typename SrcData, */ FloatA,
...
@@ -415,134 +486,177 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
-       auto b_blockwise_copy =
+       auto b0_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                               BElementwiseOperation,
+                                               B0ElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
-                                               Sequence<K0PerBlock, NPerBlock, K1>,
-                                               BBlockTransferThreadClusterLengths_K0_N_K1,
-                                               BBlockTransferThreadClusterArrangeOrder,
-                                               FloatB,
-                                               FloatB,
+                                               Sequence<BK0, NPerBlock, BK1>,
+                                               B0BlockTransferThreadClusterLengths_K0_L_K1,
+                                               B0BlockTransferThreadClusterArrangeOrder,
+                                               FloatB0,
+                                               FloatB0,
                                                decltype(b0_grid_desc_k0_l_k1),
-                                               decltype(b_block_desc_k0perblock_nperblock_k1),
-                                               BBlockTransferSrcAccessOrder,
+                                               decltype(b0_block_desc_k0perblock_lperblock_k1),
+                                               B0BlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
-                                               BBlockTransferSrcVectorDim,
+                                               B0BlockTransferSrcVectorDim,
                                                2,
-                                               BBlockTransferSrcScalarPerVector,
-                                               BBlockTransferDstScalarPerVector_K1,
+                                               B0BlockTransferSrcScalarPerVector,
+                                               B0BlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
-                                               BThreadTransferSrcResetCoordinateAfterRun,
+                                               B0ThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b0_grid_desc_k0_l_k1,
-               make_multi_index(0, n_block_data_idx_on_grid, 0),
-               b_element_op,
-               b_block_desc_k0perblock_nperblock_k1,
+               make_multi_index(0, 0, 0),
+               b0_element_op,
+               b0_block_desc_k0perblock_lperblock_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        /*******************************************************************************/
        // Gemm0
        constexpr auto WmmaK = 16;
-       constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
+       constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);

        auto blockwise_gemm0 =
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
                                                                       FloatA,
-                                                                      FloatB,
-                                                                      FloatAcc,
+                                                                      FloatB0,
+                                                                      FloatAcc0,
                                                                       decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                      decltype(b_block_desc_k0perblock_nperblock_k1),
+                                                                      decltype(b0_block_desc_k0perblock_lperblock_k1),
                                                                       MPerWmma,
-                                                                      NPerWmma,
+                                                                      LPerWmma,
                                                                       MRepeat,
-                                                                      NRepeat,
+                                                                      LRepeat,
                                                                       KPack>{};

        // Prepare Register for A*B0 matrix
-       auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
+       auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();

+       // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
+       constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
+           blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
+
+       constexpr auto mrepeat            = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
+       constexpr auto mwave              = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
+       constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
+       constexpr auto lrepeat            = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
+       constexpr auto lwave              = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
+       constexpr auto lsubgroup          = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
+       constexpr auto laccvgprs          = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
+
+       constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
+           acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
+           make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lrepeat, lsubgroup)),
+                      make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
+                      make_pass_through_transform(laccvgprs)),
+           make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
+           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        /*******************************************************************************/
-       constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
        // LDS allocation for A and B: be careful of alignment
-       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatA*>(p_shared),
-           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-       auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
-           b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
+           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
+       auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
+           b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-       constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+       constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        const auto a_block_reset_copy_step =
            make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
-       const auto b_block_reset_copy_step =
-           make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
+       const auto b0_block_reset_copy_step =
+           make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);

-       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
        /*******************************************************************************/
        // softmax
        /*******************************************************************************/
-       auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAcc*>(p_shared),
-           math::integer_least_multiple(BlockSize, max_lds_align));
-
-       // get acc0 8D thread cluster
-       constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
-           blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
-           blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
-       constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
-       constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
-       constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
-       constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
-       constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
-       constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
-       constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
-       constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
+       auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
+           SharedMemTrait::reduction_space_size_aligned);
+
+       // get acc0 7D thread cluster
+       constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
+           blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() /
+           blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
+       constexpr auto t_mrepeat            = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0);
+       constexpr auto t_mwave              = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1);
+       constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2);
+       constexpr auto t_lrepeat            = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3);
+       constexpr auto t_lwave              = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
+       constexpr auto t_lsubgroup          = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
+       constexpr auto t_laccvgprs          = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);

        // get acc0 thread map
-       constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
-           make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
-                      make_pass_through_transform(I1)),
-           make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-       constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
-           make_tuple(
-               make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
-           make_tuple(Sequence<0, 1, 2>{}),
-           make_tuple(Sequence<0>{}));
-       const auto threadid_to_m_n_thread_cluster_adaptor =
-           chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
+       constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
+           make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
+                      make_pass_through_transform(I1)),
+           make_tuple(Sequence<0>{}, Sequence<1>{}),
+           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+       constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor(
+           make_tuple(make_merge_transform(make_tuple(
+               t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))),
+           make_tuple(Sequence<0, 1, 2>{}),
+           make_tuple(Sequence<0>{}));
+       const auto threadid_to_l_n_thread_cluster_adaptor =
+           chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor);

        // get acc0 2D thread cluster & 2D thread slice
-       constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
-           make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
-       constexpr auto thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-           make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
+       constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed(
+           make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup,
+                      t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs));
+       constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed(
+           make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));
        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
-                                                 FloatGemmAcc,
-                                                 decltype(threadid_to_m_n_thread_cluster_adaptor),
-                                                 decltype(thread_cluster_desc_m_n),
-                                                 decltype(thread_slice_desc_m_n)>{};
+                                                 FloatAcc0,
+                                                 decltype(threadid_to_l_n_thread_cluster_adaptor),
+                                                 decltype(thread_cluster_desc_m_l),
+                                                 decltype(thread_slice_desc_m_l)>{};

        // Initialize running sum and max of exponentiating row vectors
        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
        running_sum     = 0;
        running_sum_new = 0;
-       running_max     = NumericLimits<FloatGemmAcc>::Lowest();
-       running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
+       running_max     = NumericLimits<FloatAcc0>::Lowest();
+       running_max_new = NumericLimits<FloatAcc0>::Lowest();
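Reviewer note: running_max / running_sum implement the online (streaming) softmax used by flash attention: each new tile of logits updates the row maximum and rescales the previous partial sum, so the full row never has to be materialized. A scalar sketch of that update rule for one row (illustrative only, independent of the kernel's buffers):

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // One logical softmax row, processed tile by tile as flash attention does.
    const std::vector<std::vector<float>> tiles = {{1.0f, 3.0f}, {2.5f, 0.5f}, {4.0f, 1.0f}};

    float running_max = -INFINITY;
    float running_sum = 0.0f;

    for(const auto& tile : tiles)
    {
        float tile_max = -INFINITY;
        for(float x : tile)
            tile_max = std::fmax(tile_max, x);

        const float new_max = std::fmax(running_max, tile_max);
        // rescale the old partial sum to the new maximum, then add this tile's terms
        running_sum *= std::exp(running_max - new_max);
        for(float x : tile)
            running_sum += std::exp(x - new_max);
        running_max = new_max;
    }
    // softmax denominator for the whole row, without ever holding the row at once
    std::printf("row max = %f, row sum = %f\n", running_max, running_sum);
    return 0;
}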
        /*******************************************************************************/
        // set up Gemm1
        /*******************************************************************************/

        // B1 matrix in LDS memory, dst of blockwise copy
-       constexpr auto b1_block_desc_l0perblock_nperblock_l1 =
-           GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1();
+       constexpr auto b1_block_desc_l0perblock_nperblock_l1 =
+           GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
+       constexpr auto b1_block_slice_copy_step = make_multi_index(L0PerBlock, 0, 0);
+
+       // A1 matrix in VGPR
+       constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 =
+           make_tuple(Number<L0PerBlock * L1Value / laccvgprs>{},
+                      Number<mrepeat * mwave * mthreadpersubgroup>{},
+                      Number<laccvgprs>{});
+
+       // Data duplicated dimension
+       constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
+       constexpr auto A1ThreadSliceMPerBlock  = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
+       constexpr auto A1ThreadSliceL1         = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
+       // A1 has duplicated data
+       constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;
+       constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
+           make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
+           make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));

        // A1 matrix blockwise copy
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
-           FloatAcc,
+           FloatAcc0,
            FloatA,
-           decltype(acc_thread_desc_k0_m_k1),
-           decltype(a1_thread_desc_k0_m_k1),
+           decltype(acc0_thread_desc_l0perblock_mperblock_l1),
+           decltype(a1_thread_desc_l0perblock_mperblock_l1),
            tensor_operation::element_wise::PassThrough,
-           Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
-           Sequence<1, 0, 2>,
+           Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
+           Sequence<0, 1, 2>,
            2,
-           n4,
+           laccvgprs,
            // dst Rowlane
            // 0x76543210 0xfedcba98
            // src Rowlane
...
@@ -551,68 +665,77 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // B1 matrix blockwise copy
        auto b1_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                               BElementwiseOperation,
+                                               B0ElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
-                                               Sequence<B1K0, Gemm1NPerBlock, B1K1>,
-                                               B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+                                               Sequence<L0, NPerBlock, L1>,
+                                               B1BlockTransferThreadClusterLengths_L0_N_L1,
                                                B1BlockTransferThreadClusterArrangeOrder,
-                                               FloatAB,
-                                               FloatAB,
-                                               decltype(b1_grid_desc_bk0_n_bk1),
-                                               decltype(b1_block_desc_bk0_n_bk1),
+                                               FloatB1,
+                                               FloatB1,
+                                               decltype(b1_grid_desc_l0_n_l1),
+                                               decltype(b1_block_desc_l0perblock_nperblock_l1),
                                                B1BlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                B1BlockTransferSrcVectorDim,
                                                2,
                                                B1BlockTransferSrcScalarPerVector,
-                                               B1BlockTransferDstScalarPerVector_BK1,
+                                               B1BlockTransferDstScalarPerVector_L1,
                                                1,
                                                1,
                                                B1ThreadTransferSrcResetCoordinateAfterRun,
                                                true, // DstResetCoord
                                                NumGemmKPrefetchStage>(
-               b1_grid_desc_bk0_n_bk1,
-               make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+               b1_grid_desc_l0_n_l1,
+               make_multi_index(0, n_block_data_idx_on_grid, 0),
                b1_element_op,
-               b1_block_desc_bk0_n_bk1,
+               b1_block_desc_l0perblock_nperblock_l1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
-       auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
-           a1_thread_desc_k0_m_k1.GetElementSpaceSize());
-       auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatB*>(p_shared),
-           b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+       auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+           a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
+       auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<FloatB1*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+           b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());

        auto blockwise_gemm1 =
-           BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
-                                                                 FloatA,
-                                                                 FloatB,
-                                                                 FloatAcc,
-                                                                 decltype(a1_thread_desc_k0perblock_mperblock_k1),
-                                                                 decltype(b1_block_desc_k0perblock_nperblock_k1),
-                                                                 MPerWmma,
-                                                                 NPerWmma,
-                                                                 MRepeat,
-                                                                 NRepeat,
-                                                                 KPack>{make_tuple(0, 0, 0, 0, 0)};
+           BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+                                                                      FloatA,
+                                                                      FloatB1,
+                                                                      FloatAcc1,
+                                                                      decltype(a1_thread_desc_l0perblock_mperblock_l1),
+                                                                      decltype(b1_block_desc_l0perblock_nperblock_l1),
+                                                                      MPerWmma,
+                                                                      NPerWmma,
+                                                                      MRepeat,
+                                                                      NRepeat,
+                                                                      KPack>{};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

-       const index_t num_gemm1_k_block_outer_loop =
-           b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
-       constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
+       const index_t num_gemm1_l_block_outer_loop =
+           b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
+       constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (L0PerBlock * L1Value);
        // Initialize C
-       StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc1_thread_buf.Size(), true> c_thread_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
        c_thread_buf.Clear();

        /*******************************************************************************/
        // Flash Attention
        // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
-       index_t gemm1_k_block_outer_index = 0;
+       index_t gemm1_l_block_outer_index = 0;
        // Outer loop, along GEMM_L
        // Inner loop, along GEMM_K
        do
        {
+           auto l_block_data_idx_on_grid =
+               __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock);
+           if(c0_matrix_mask.IsTileSkippable(
+                  m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock))
+           {
+               continue;
+           }
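Reviewer note: IsTileSkippable lets the outer loop drop whole L-tiles that the mask would zero out anyway; for the causal (upper-triangle) mask this skips every key tile that lies entirely past the diagonal of the current query tile. A scalar sketch of that predicate (tile origins and sizes are illustrative, not the C0MatrixMask implementation):

#include <cstdio>

// Illustrative causal-mask tile test: a tile of logits with row origin m_tile and
// column origin l_tile can be skipped when every column in it lies strictly
// "in the future" of every row, i.e. l_tile > m_tile + m_size - 1.
bool is_tile_skippable(int m_tile, int l_tile, int m_size, int l_size)
{
    (void)l_size; // the key-tile width is not needed for this particular test
    return l_tile > m_tile + m_size - 1;
}

int main()
{
    std::printf("%d\n", is_tile_skippable(0, 128, 128, 128));   // 1: fully masked, skip it
    std::printf("%d\n", is_tile_skippable(128, 128, 128, 128)); // 0: diagonal tile, keep it
    return 0;
}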
            // gemm0 start, A-B swaped
+           const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
            GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                              a_block_desc_k0perblock_mperblock_k1,
                                                              a_blockwise_copy,
...
@@ -620,33 +743,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                              a_block_buf,
                                                              a_block_slice_copy_step,
                                                              b0_grid_desc_k0_l_k1,
-                                                             b_block_desc_k0perblock_nperblock_k1,
-                                                             b_blockwise_copy,
+                                                             b0_block_desc_k0perblock_lperblock_k1,
+                                                             b0_blockwise_copy,
                                                              b0_grid_buf,
-                                                             b_block_buf,
-                                                             b_block_slice_copy_step,
-                                                             blockwise_gemm,
-                                                             acc_thread_buf,
+                                                             b0_block_buf,
+                                                             b0_block_slice_copy_step,
+                                                             blockwise_gemm0,
+                                                             acc0_thread_buf,
                                                              K0BlockMainLoop);

            // do MNK padding or upper triangular masking
            if constexpr(MaskOutUpperTriangle || PadN)
            {
-               // 8d thread_desc in thread scope
+               // 7d thread_desc in thread scope
                constexpr auto c_thread_lengths =
-                   blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                   blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();

-               // 8d block_desc in block scope
+               // 7d block_desc in block scope
                constexpr auto c_block_lengths =
-                   blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                   blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();

-               constexpr auto M0 = c_block_lengths[I0];
-               constexpr auto N0 = c_block_lengths[I1];
-               constexpr auto M1 = c_block_lengths[I2];
-               constexpr auto N1 = c_block_lengths[I3];
-               constexpr auto M2 = c_block_lengths[I4];
-               constexpr auto N2 = c_block_lengths[I5];
-               constexpr auto N3 = c_block_lengths[I6];
-               constexpr auto N4 = c_block_lengths[I7];
+               constexpr auto MREPEAT         = c_block_lengths[I0];
+               constexpr auto MWAVE           = c_block_lengths[I1];
+               constexpr auto MTHREADSubGroup = c_block_lengths[I2];
+               constexpr auto LREPEAT         = c_block_lengths[I3];
+               constexpr auto LWAVE           = c_block_lengths[I4];
+               constexpr auto LSUBGROUP       = c_block_lengths[I5];
+               constexpr auto LACCVGPRS       = c_block_lengths[I6];

                // works like multi-dimension static_for (static_ford), but provides both the linear
                // index as well as n-d index
...
@@ -656,36 +778,34 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
false
>
;
// SnakeCurved
auto
acc0_thread_origin
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
8
D
(
auto
acc0_thread_origin
=
blockwise_gemm
0
.
CalculateCThreadOriginDataIndex
7
D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
Number
<
0
>
{},
Number
<
0
>
{});
constexpr
auto
block_idx_to_m_
n
_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
block_idx_to_m_
l
_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
0
,
M1
,
M2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
REPEAT
,
MWAVE
,
MTHREADSubGroup
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_unmerge_transform
(
make_tuple
(
LREPEAT
,
LWAVE
,
LSUBGROUP
,
LACCVGPRS
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}));
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
auto
m_local
=
block_idx_to_m_l_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
l_local
=
block_idx_to_m_l_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n
_global
=
n
_local
+
n
_block_data_idx_on_grid
;
auto
l
_global
=
l
_local
+
l
_block_data_idx_on_grid
;
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n
_global
))
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
l
_global
))
{
{
acc_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
acc
0
_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
}
else
else
{
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
acc_element_op
(
acc
0
_thread_buf
(
i
),
acc
0
_thread_buf
[
i
]);
}
}
});
});
}
}
else
else
{
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
{
static_for
<
0
,
acc
0
_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
[
&
](
auto
i
)
{
acc_element_op
(
acc
0
_thread_buf
(
i
),
acc
0
_thread_buf
[
i
]);
});
}
}
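For orientation, the c0_matrix_mask object queried above only has to answer two questions: is a single (m, l) score element masked, and can a whole MPerBlock x LPerBlock tile be skipped outright. A hypothetical stand-in for the causal (upper-triangular) case, with invented names and device qualifiers omitted, might look like the sketch below; it is not the kernel's actual mask implementation.

// Hypothetical illustration only; the kernel uses its own mask type and the
// names below are invented. Device qualifiers are omitted for brevity.
struct NaiveCausalMask
{
    // An element (m, l) of the score matrix is masked when the key position l
    // lies strictly to the right of the query position m (upper triangle).
    bool IsMaskedElement(int m, int l) const { return l > m; }

    // A whole MPerBlock x LPerBlock tile can be skipped when even its least-masked
    // corner (largest m in the tile, smallest l in the tile) is masked.
    bool IsTileSkippable(int m_tile_origin, int l_tile_origin, int m_per_block, int /*l_per_block*/) const
    {
        return IsMaskedElement(m_tile_origin + m_per_block - 1, l_tile_origin);
    }
};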
...
@@ -697,7 +817,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;
            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;

-            blockwise_softmax.Run(acc_thread_buf, workspace_buf);
+            blockwise_softmax.Run(acc0_thread_buf, workspace_buf);

            // TODO: may convert to log domain
            running_max_new = mathext::max(max, running_max);
...
@@ -717,79 +837,80 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            acc1_thread_buf.Clear();

            // preload data into LDS
-            b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
+            b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);

-            b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
+            b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
                                                 b1_block_slice_copy_step);

            block_sync_lds(); // wait for reduction LDS read

-            b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
+            b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);

            // main body
-            if constexpr(num_gemm1_k_block_inner_loop > 1)
+            if constexpr(num_gemm1_l_block_inner_loop > 1)
            {
-                static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
-                    // Data cast from FloatAcc to FloatA happen here
-                    a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
-                                          make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
-                                          acc_thread_buf,
-                                          a1_thread_desc_k0_m_k1,
+                static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
+                    // Data cast from FloatAcc0 to FloatA happen here
+                    a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
+                                          make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
+                                          acc0_thread_buf,
+                                          a1_thread_desc_l0perblock_mperblock_l1,
                                          make_tuple(I0, I0, I0),
                                          a1_thread_buf);

-                    b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
+                    b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);

                    block_sync_lds();

-                    gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+                    blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

                    block_sync_lds();

-                    b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
+                    b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
                                                         b1_block_slice_copy_step);

-                    b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
+                    b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
                });
            }
            // tail
            {
                a1_blockwise_copy.Run(
-                    acc_thread_desc_k0_m_k1,
+                    acc0_thread_desc_l0perblock_mperblock_l1,
                    make_tuple(
-                        Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
-                    acc_thread_buf,
-                    a1_thread_desc_k0_m_k1,
+                        Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
+                    acc0_thread_buf,
+                    a1_thread_desc_l0perblock_mperblock_l1,
                    make_tuple(I0, I0, I0),
                    a1_thread_buf);

                block_sync_lds();

-                gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+                blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
            }
            } // end gemm1

-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+            constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
+                blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

-            constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-            constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-            constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-            constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-            constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-            constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-            constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-            constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+            constexpr auto c_mrepeat =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
+            constexpr auto c_mwave =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
+            constexpr auto c_mthreadpersubgroup =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
+            constexpr auto c_nrepeat =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
+            constexpr auto c_nwave =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
+            constexpr auto c_nsubgroup =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
+            constexpr auto c_naccvgprs =
+                c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

-            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-                make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
+            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
+                make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
+                           c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));
            constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
            constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);

            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
-                    FloatGemmAcc c    = c_thread_buf[I];    // O
-                    FloatGemmAcc c_new =
+                    FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
+                    FloatAcc1 c    = c_thread_buf[I];    // O
+                    FloatAcc1 c_new =
                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
                        running_sum_new[iM];
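Unrolling that update over all GEMM_L tiles shows the block-wise computation is exact rather than approximate: after the last tile,

$$O \;=\; \frac{\sum_j e^{S_j - m_{\mathrm{final}}}\, V_j}{\sum_j e^{S_j - m_{\mathrm{final}}}} \;=\; \operatorname{softmax}(S)\,V,$$

so the per-tile division by running_sum_new only re-normalizes intermediate results and cancels out of the final answer.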
...
@@ -798,26 +919,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                });
            });

-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
                                                a_block_reset_copy_step); // rewind K
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
-                                                b_block_reset_copy_step); // rewind K and step N
+            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
+                                                 b0_block_reset_copy_step); // rewind K and step N

            // update before next j iteration
            running_max = running_max_new;
            running_sum = running_sum_new;

            block_sync_lds(); // wait for gemm1 LDS read
-        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop);
+        } while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);

        /*******************************************************************************/
        // write out to C, implement shuffle
        {
-            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
+                blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
            // This API Provide All dimension (size) you need
-            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
-                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
+                blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
            constexpr auto MSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
...
@@ -852,7 +973,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
+            const auto c_thread_mtx_on_block = blockwise_gemm0.CalculateCThreadOriginDataIndex(I0, I0);
            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
...
@@ -877,7 +998,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
+                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  View file @ 5df713ef
...
@@ -1313,8 +1313,8 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
-          index_t LowEightRowlaneIdx,
-          index_t HighEightRowLaneIdx,
+          uint32_t LowEightRowlaneIdx,
+          uint32_t HighEightRowLaneIdx,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  View file @ 5df713ef
...
@@ -369,7 +369,7 @@ struct WmmaGemm
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
-    using CIndex4D = MultiIndex<4>;
+    using CIndex3D = MultiIndex<3>;

    __host__ __device__ constexpr WmmaGemm()
    {
...
@@ -421,6 +421,46 @@ struct WmmaGemm
                       Sequence<5>{}));
    }

+    // Transposed WMMA Output C' = B' * A'
+    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
+    __host__ __device__ static constexpr auto
+    MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
+    {
+        const auto MBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
+        const auto NBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
+        const auto MWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
+        const auto NWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
+
+        return transform_tensor_descriptor(
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
+            make_tuple(
+                make_pass_through_transform(MBlockxRepeat),
+                make_pass_through_transform(MWave),
+                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
+                make_pass_through_transform(NBlockxRepeat),
+                make_pass_through_transform(NWave),
+                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
+                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5, 6>{}));
+    }
    __device__ static constexpr index_t GetRegSizePerWmma()
    {
        return wmma_instr.num_acc_vgprs_per_wave;
...
@@ -493,6 +533,14 @@ struct WmmaGemm
        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

+    __device__ static CIndex3D GetBeginOfThreadBlk3D()
+    {
+        index_t n_offset = GetLaneIdUnderSubGroup();
+        index_t m_offset = GetSubGroupId();
+
+        return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
+    }
+
    static constexpr auto wmma =
        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
    static constexpr auto wmma_instr = wmma.selected_wmma;
...
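The only non-obvious piece of the descriptor transform added in this file is the final unmerge, which splits the per-WMMA N length into (num_subgroups, num_acc_vgprs_per_wave). A standalone sketch of what an unmerge does to indices, using an assumed 2 x 8 split rather than values read from wmma_instr:

// Illustration only: the (2, 8) split is assumed, not read from wmma_instr.
#include <cstdio>

int main()
{
    const int num_subgroups          = 2; // assumed
    const int num_acc_vgprs_per_wave = 8; // assumed

    // Unmerge maps a flat index n in [0, 16) to the pair (subgroup, acc_vgpr),
    // with the last dimension varying fastest.
    for(int n = 0; n < num_subgroups * num_acc_vgprs_per_wave; ++n)
    {
        const int subgroup = n / num_acc_vgprs_per_wave;
        const int acc_vgpr = n % num_acc_vgprs_per_wave;
        std::printf("n=%2d -> (subgroup=%d, acc_vgpr=%d)\n", n, subgroup, acc_vgpr);
    }
    return 0;
}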