gaoqiong / composable_kernel · Commits

Commit 5df713ef
authored Feb 11, 2023 by aska-0096

    save progress

parent a6b2f1c1

Showing 7 changed files with 923 additions and 767 deletions.
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (+31, -29)
  example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc  (+1, -1)
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (+70, -7)
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (+365, -443)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  (+405, -284)
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (+2, -2)
  include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  (+49, -1)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 /*
-Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
+Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                          |-----------------|
                                                                  Gemm0
                                                          |-------------------------------------|
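For orientation, the relabeled comment reads as C[g, m, n] = Softmax(A[g, m, k] * B0[g, k, l]) * B1[g, l, n], with the softmax taken row-wise over the l dimension. A minimal host-side sketch of that semantics for one batch (illustrative only; the function name, float types, and row-major layout are assumptions, not the CK reference path this example actually uses):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Naive Gemm0 -> row softmax -> Gemm1 for a single batch.
    // A: M x K, B0: K x L, B1: L x N, C: M x N (all row-major).
    void gemm_softmax_gemm(const std::vector<float>& A,
                           const std::vector<float>& B0,
                           const std::vector<float>& B1,
                           std::vector<float>& C,
                           int M, int K, int L, int N)
    {
        std::vector<float> S(M * L, 0.f); // Gemm0 output, softmax input
        for(int m = 0; m < M; ++m)
            for(int l = 0; l < L; ++l)
                for(int k = 0; k < K; ++k)
                    S[m * L + l] += A[m * K + k] * B0[k * L + l];

        // Numerically stable row-wise softmax over l.
        for(int m = 0; m < M; ++m)
        {
            float mx = S[m * L];
            for(int l = 1; l < L; ++l) mx = std::max(mx, S[m * L + l]);
            float sum = 0.f;
            for(int l = 0; l < L; ++l) sum += (S[m * L + l] = std::exp(S[m * L + l] - mx));
            for(int l = 0; l < L; ++l) S[m * L + l] /= sum;
        }

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int l = 0; l < L; ++l) acc += S[m * L + l] * B1[l * N + n];
                C[m * N + n] = acc;
            }
    }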
@@ -39,7 +39,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ADataType        = F16;
 using B0DataType       = F16;
 using B1DataType       = F16;
-using AccDataType      = F32;
+using Acc0DataType     = F32;
+using Acc1DataType     = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
 using Acc0BiasDataType = ck::Tuple<>;
@@ -67,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;

 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -76,11 +77,12 @@ using DeviceGemmInstance =
         ADataType,
         B0DataType,
         B1DataType,
         CDataType,
         Acc0BiasDataType,
+        Acc0DataType,
         Acc1BiasDataType,
-        AccDataType,
+        Acc1DataType,
         CShuffleDataType,
         CDataType,
         AElementOp,
         B0ElementOp,
         Acc0ElementOp,
@@ -91,21 +93,21 @@ using DeviceGemmInstance =
         TensorSpecB0,
         TensorSpecB1,
         TensorSpecC,
         1,
         256,
         128,         // MPerBlock
-        128,         // NPerBlock
-        32,          // KPerBlock
-        64,          // Gemm1NPerBlock
-        32,          // Gemm1KPerBlock
-        8,           // AK1
-        8,           // BK1
-        2,           // B1K1
-        32,          // MPerXDL
-        32,          // NPerXDL
-        1,           // MXdlPerWave
-        4,           // NXdlPerWave
-        2,           // Gemm1NXdlPerWave
+        128,         // LPerBlock
+        4,           // K0PerBlock
+        8,           // K1
+        64,          // NPerBlock
+        4,           // L0PerBlock
+        8,           // L1
+        16,          // MPerWMMA
+        16,          // LPerWMMA
+        16,          // NPerWMMA
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1,           // MRepeat
+        8,           // LRepeat
+        4,           // NRepeat
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
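A quick consistency check on the new tile numbers, using the usual CK decomposition PerBlock = Repeat * Waves * PerWMMA (that relation is an assumption here, not stated in this commit): with 256 threads per block and 32-lane waves there are 8 waves, and the parameters above give wave_m = 8, wave_n = 1, matching the in-diff comment. A hypothetical compile-time check:

    #include <cstdio>

    // Illustrative sketch only; constant names mirror the template comments above.
    int main()
    {
        constexpr int MPerBlock = 128, MPerWMMA = 16, MRepeat = 1;
        constexpr int LPerBlock = 128, LPerWMMA = 16, LRepeat = 8;
        constexpr int NPerBlock = 64,  NPerWMMA = 16, NRepeat = 4;

        constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 8
        constexpr int LWaves = LPerBlock / (LRepeat * LPerWMMA); // 1
        constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // 1

        static_assert(MPerBlock == MRepeat * MWaves * MPerWMMA, "M tile mismatch");
        static_assert(LPerBlock == LRepeat * LWaves * LPerWMMA, "L tile mismatch");
        static_assert(NPerBlock == NRepeat * NWaves * NPerWMMA, "N tile mismatch");

        std::printf("MWaves=%d LWaves=%d NWaves=%d\n", MWaves, LWaves, NWaves);
        return 0;
    }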
@@ -113,44 +115,44 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<4, 64, 1>, // BBlockTransfer
+        S<4, 64, 1>, // B0BlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         8,
         8,
         true,
-        S<16, 16, 1>, // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
+        S<4, 64, 1>, // B1BlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
         1,
         4,
         2,
         8,
         8,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        4,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
-                                                                                AccDataType,
-                                                                                AccDataType,
+                                                                                Acc0DataType,
+                                                                                Acc1DataType,
                                                                                 AElementOp,
                                                                                 B0ElementOp,
                                                                                 Acc0ElementOp>;

 // Ref Softmax: fp32 in, fp16 out
 using ReferenceSoftmaxInstance =
-    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
+    ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;

 // Ref Gemm1: fp16 in, fp16 out
 using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B1DataType,
                                                                                 CDataType,
-                                                                                AccDataType,
+                                                                                Acc1DataType,
                                                                                 AElementOp,
                                                                                 B1ElementOp,
                                                                                 CElementOp>;
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

@@ -198,7 +198,7 @@ int run(int argc, char* argv[])
     Tensor<ADataType> a_g_m_k({BatchCount, M, K});
     Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
     Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-    Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
+    Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
     Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
     Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -129,11 +129,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
         return make_tuple(c_thread_m, c_thread_n);
     }

-    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
-        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
-        Tuple4 b_origin = CalculateBThreadOriginDataIndex())
-        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+    // using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+    // __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
+    //     Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+    //     Tuple4 b_origin = CalculateBThreadOriginDataIndex())
+    //     : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -303,8 +304,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                                              B_K1,
                                                              B_K1>;

-    AThreadCopy a_thread_copy_;
-    BThreadCopy b_thread_copy_;
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    // AThreadCopy a_thread_copy_;
+    // BThreadCopy b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
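The two hunks above move the thread-copy origins out of constructor parameters and into default member initializers, so the blockwise GEMM gains a plain default constructor. A stripped-down sketch of that C++ pattern (hypothetical names, not CK code):

    // Before: origins were constructor arguments with defaults.
    // After: default member initializers call the static origin functions,
    // so the type is default-constructible with no argument plumbing.
    struct BlockwiseGemmSketch
    {
        static int CalculateAThreadOrigin() { return 0; } // stand-in for the real index math
        static int CalculateBThreadOrigin() { return 0; }

        int a_thread_copy_{CalculateAThreadOrigin()};
        int b_thread_copy_{CalculateBThreadOrigin()};
    };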
@@ -425,6 +428,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
         return make_tuple(c_thread_m, c_thread_n);
     }

+    template <index_t m0, index_t n0>
+    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
+    {
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
+
+        return make_tuple(Number<m0>{},
+                          blk_idx[I0],
+                          waveId_m,
+                          Number<n0>{},
+                          waveId_n,
+                          blk_idx[I1],
+                          blk_idx[I2]);
+    }
+
     __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
@@ -438,6 +460,30 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
     }

+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
+        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
+
+        // constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
+        // constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+
+        return make_naive_tensor_descriptor_packed(
+            //         |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
+            //         |NThreadPerSubGroup |MAccVgprs
+            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
+    }
+
     // Thread level, register descriptor. Vector-write
     __host__ __device__ static constexpr auto
     GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
@@ -483,6 +529,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
             c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
     }

+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
+        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<MPerWMMA>{},
+                                                           Number<NRepeat>{},
+                                                           Number<NWaves>{},
+                                                           Number<NPerWMMA>{}));
+
+        return wmma_gemm
+            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
+    }
+
     // Provide dimension size
     __host__ __device__ static constexpr auto
     GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
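The "transposed WMMA output C' = B' * A'" comments in these new descriptor paths rest on the identity (A * B)^T = B^T * A^T: swapping the operands lets a kernel that naturally produces a transposed tile still write the intended layout. A tiny host-side check of the identity (illustrative only, unrelated to CK's device code):

    #include <cassert>

    // If C = A * B (A is 2x3, B is 3x2), then B^T * A^T equals C^T elementwise.
    int main()
    {
        const float A[2][3] = {{1, 2, 3}, {4, 5, 6}};
        const float B[3][2] = {{7, 8}, {9, 10}, {11, 12}};

        float C[2][2]  = {}; // C  = A * B
        float Ct[2][2] = {}; // Ct = B^T * A^T
        for(int m = 0; m < 2; ++m)
            for(int n = 0; n < 2; ++n)
                for(int k = 0; k < 3; ++k)
                {
                    C[m][n] += A[m][k] * B[k][n];
                    Ct[n][m] += B[k][n] * A[m][k]; // (B^T)(n,k) * (A^T)(k,m)
                }

        for(int m = 0; m < 2; ++m)
            for(int n = 0; n < 2; ++n)
                assert(Ct[n][m] == C[m][n]);
        return 0;
    }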
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

(diff collapsed, not shown)

include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

(diff collapsed, not shown)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1313,8 +1313,8 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          index_t LowEightRowlaneIdx,
-          index_t HighEightRowLaneIdx,
+          uint32_t LowEightRowlaneIdx,
+          uint32_t HighEightRowLaneIdx,
           typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                              bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -369,7 +369,7 @@ struct WmmaGemm
     static constexpr auto I5 = Number<5>{};

     using CIndex   = MultiIndex<2>;
-    using CIndex4D = MultiIndex<4>;
+    using CIndex3D = MultiIndex<3>;

     __host__ __device__ constexpr WmmaGemm()
     {
@@ -421,6 +421,46 @@ struct WmmaGemm
                        Sequence<5>{}));
     }

+    // Transposed WMMA Output C' = B' * A'
+    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
+    __host__ __device__ static constexpr auto
+    MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
+    {
+        const auto MBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
+        const auto NBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
+        const auto MWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
+        const auto NWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
+
+        return transform_tensor_descriptor(
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
+            make_tuple(
+                make_pass_through_transform(MBlockxRepeat),
+                make_pass_through_transform(MWave),
+                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
+                make_pass_through_transform(NBlockxRepeat),
+                make_pass_through_transform(NWave),
+                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
+                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5, 6>{}));
+    }
+
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
         return wmma_instr.num_acc_vgprs_per_wave;
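This new descriptor passes five dimensions through unchanged and unmerges the last one into (num_subgroups, num_acc_vgprs_per_wave), which as plain index arithmetic is a split with the last component fastest-varying. A sketch with hypothetical sizes (16 = 2 subgroups x 8 accumulator VGPRs; these values are assumptions for illustration, not read from a specific wmma_instr):

    #include <cstdio>

    constexpr int kNumSubGroups       = 2;
    constexpr int kNumAccVgprsPerWave = 8;

    int main()
    {
        // Map a flat per-WMMA index n to the unmerged (subgroup, acc_vgpr) pair.
        for(int n = 0; n < kNumSubGroups * kNumAccVgprsPerWave; ++n)
        {
            const int subgroup = n / kNumAccVgprsPerWave; // upper unmerged index
            const int acc_vgpr = n % kNumAccVgprsPerWave; // lower unmerged index
            std::printf("n=%2d -> (subgroup=%d, acc_vgpr=%d)\n", n, subgroup, acc_vgpr);
        }
        return 0;
    }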
@@ -493,6 +533,14 @@ struct WmmaGemm
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }

+    __device__ static CIndex3D GetBeginOfThreadBlk3D()
+    {
+        index_t n_offset = GetLaneIdUnderSubGroup();
+        index_t m_offset = GetSubGroupId();
+
+        return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
+    }
+
     static constexpr auto wmma =
         WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;