gaoqiong / composable_kernel
"...community/pipeline_stable_diffusion_upscale_ldm3d.py" did not exist on "12a232efa99d7a8c33f54ae515c5a3d6fc5c8f34"
Commit c9133170, authored Mar 03, 2023 by danyao12

support fwd fp16&bf16

parent 45b52f13
Showing 6 changed files with 54 additions and 43 deletions (+54 -43)
Changed files:
- example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+2 -4)
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp (+14 -10)
- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp (+14 -11)
- include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp (+2 -0)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+2 -0)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+20 -18)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+2 -4)

@@ -5,10 +5,8 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
-add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
+add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp)
 add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp (+14 -10)

@@ -33,17 +33,20 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-using ADataType        = F16;
-using B0DataType       = F16;
-using B1DataType       = F16;
+using DataType         = BF16;
+using GemmDataType     = BF16;
+using ADataType        = DataType;
+using B0DataType      = DataType;
+using B1DataType       = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
-using CDataType        = F16;
+using CDataType        = DataType;
 using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;

@@ -81,6 +84,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,

@@ -139,7 +143,7 @@ using DeviceGemmInstance =
         8,            // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization
 
-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
                                                                                 AccDataType,

@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                 B0ElementOp,
                                                                                 Acc0ElementOp>;
 
-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
 using ReferenceSoftmaxInstance =
     ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
 
-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
 using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B1DataType,
                                                                                 CDataType,
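Note: the renamed example now derives every tensor element type from a single pair of aliases, so one source file covers both fp16 and bf16 builds. Below is a minimal, self-contained sketch of that pattern (not CK code; half_stub/bhalf_stub stand in for ck::half_t/ck::bhalf_t, which require the CK headers):

#include <cstdint>
#include <type_traits>

struct half_stub  { std::uint16_t bits; }; // stand-in for ck::half_t
struct bhalf_stub { std::uint16_t bits; }; // stand-in for ck::bhalf_t

using F16  = half_stub;
using BF16 = bhalf_stub;
using F32  = float;

using DataType     = BF16; // switch to F16 to rebuild the same example in fp16
using GemmDataType = BF16; // element type fed to the MFMA pipeline

using ADataType   = DataType;
using B0DataType  = DataType;
using B1DataType  = DataType;
using CDataType   = DataType;
using AccDataType = F32;

static_assert(std::is_same_v<ADataType, CDataType>,
              "all I/O tensors follow the single DataType alias");

int main() { return 0; }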
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp (+14 -11)

@@ -33,17 +33,20 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-using ADataType        = F16;
-using B0DataType       = F16;
-using B1DataType       = F16;
+using DataType         = F16;
+using GemmDataType     = F16;
+using ADataType        = DataType;
+using B0DataType       = DataType;
+using B1DataType       = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
-using CDataType        = F16;
+using CDataType        = DataType;
 using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;

@@ -72,7 +75,6 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,

@@ -82,6 +84,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,

@@ -140,7 +143,7 @@ using DeviceGemmInstance =
         8,            // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization
 
-// Ref Gemm0: fp16 in, fp32 out
+// Ref Gemm0: DataType in, AccDataType out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
                                                                                 AccDataType,

@@ -149,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                 B0ElementOp,
                                                                                 Acc0ElementOp>;
 
-// Ref Softmax: fp32 in, fp16 out
+// Ref Softmax: AccDataType in, DataType out
 using ReferenceSoftmaxInstance =
     ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
 
-// Ref Gemm1: fp16 in, fp16 out
+// Ref Gemm1: DataType in, DataType out
 using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B1DataType,
                                                                                 CDataType,
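Taken together with the previous file, the two renamed examples differ only in what the two aliases select: the grouped example builds fp16, the batched example builds bf16. A standalone sketch of that contrast (stub types again, since the CK scalar types are not pulled in here):

#include <cstdint>
#include <type_traits>

struct half_stub  { std::uint16_t bits; }; // stand-in for ck::half_t
struct bhalf_stub { std::uint16_t bits; }; // stand-in for ck::bhalf_t

// grouped_multihead_attention_forward.cpp selects fp16
namespace grouped { using DataType = half_stub;  using GemmDataType = half_stub;  }
// batched_multihead_attention_forward.cpp selects bf16
namespace batched { using DataType = bhalf_stub; using GemmDataType = bhalf_stub; }

static_assert(std::is_same_v<grouped::DataType, grouped::GemmDataType>);
static_assert(std::is_same_v<batched::DataType, batched::GemmDataType>);

int main() { return 0; }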
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp (+2 -0)

@@ -158,6 +158,7 @@ template <index_t NumDimG,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,

@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
+       GemmDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+2 -0)

@@ -148,6 +148,7 @@ template <index_t NumDimG,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,

@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
+       GemmDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
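Both device headers make the same two-line change: the operation gains a GemmDataType template parameter and forwards it to the gridwise kernel, so the MFMA math type no longer has to equal the A/B I/O type. A self-contained toy of that forwarding pattern (stub templates, not the CK types):

#include <type_traits>

// stands in for GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
template <typename FloatAB, typename FloatGemm, typename FloatC>
struct GridwiseKernelStub {};

// stands in for Device{Batched,Grouped}MultiheadAttentionForward_Xdl_CShuffle
template <typename ADataType, typename GemmDataType, typename CDataType>
struct DeviceOpStub
{
    using GridwiseGemm = GridwiseKernelStub<ADataType, GemmDataType, CDataType>;
};

// the I/O type and the gemm type may now differ, e.g. fp32 I/O with a 16-bit gemm type
struct bhalf_stub { unsigned short bits; };
static_assert(std::is_same_v<DeviceOpStub<float, bhalf_stub, float>::GridwiseGemm,
                             GridwiseKernelStub<float, bhalf_stub, float>>);

int main() { return 0; }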
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+20 -18)

@@ -21,6 +21,7 @@
 namespace ck {
 
 template <typename FloatAB,
+          typename FloatGemm,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatC,

@@ -126,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);
 
-        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3   = mfma.num_groups_per_blk;
         constexpr auto N4   = mfma.num_input_blks;
         constexpr auto N5   = mfma.group_size;

@@ -242,10 +243,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     {
         const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                          SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(FloatAB);
+                                        sizeof(FloatGemm);
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(FloatAB);
+            sizeof(FloatGemm);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
                                           sizeof(FloatGemmAcc);

@@ -495,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,

@@ -526,7 +527,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,

@@ -554,12 +555,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // acc1[m][o] += acc[m][n] * B1[n][o]
         // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         auto blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            FloatAB,
+            FloatGemm,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),

@@ -579,11 +581,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // LDS allocation for A and B: be careful of alignment
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            static_cast<FloatGemm*>(p_shared) + SharedMemTrait::a_block_space_offset,
             a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b_block_space_offset,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);

@@ -658,7 +660,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // A1 matrix blockwise copy
         auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            FloatAB,
+            FloatGemm,
             decltype(acc_thread_desc_k0_m_k1),
             decltype(a1_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,

@@ -677,7 +679,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           FloatGemm,
            decltype(b1_grid_desc_bk0_n_bk1),
            decltype(b1_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder,

@@ -698,12 +700,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            make_multi_index(0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
 
-        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemm>(
             a1_thread_desc_k0_m_k1.GetElementSpaceSize());
 
         // reuse LDS space for gemm0's b_block_buf
         auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size

@@ -716,11 +718,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         constexpr index_t Gemm1KPack =
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            FloatAB,
+            FloatGemm,
             FloatGemmAcc,
             decltype(a1_thread_desc_k0_m_k1),
             decltype(b1_block_desc_bk0_n_bk1),

@@ -736,7 +738,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             Gemm1KPack,
             true,       // TransposeC
             Gemm1KPack, // AMmaKStride
-            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
+            Gemm1KPack * XdlopsGemm<FloatGemm, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
             // BMmaKStride
             make_tuple(0, 0, 0, 0)}; // A_origin

@@ -1100,7 +1102,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // workaround compiler issue; see ck/ck.hpp
        if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-                    is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
+                    (is_same_v<FloatGemm, bhalf_t>) && MPerBlock == 256 && NPerBlock == 128 &&
                     Gemm1NPerBlock == 128)
        {
            __builtin_amdgcn_sched_barrier(0);
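In the gridwise kernel the new FloatGemm parameter replaces FloatAB everywhere the on-chip math type matters: MFMA instruction selection, LDS byte counts, and the A1/B1 staging buffers. A self-contained toy of the shared-memory arithmetic (not CK code; the block sizes below are hypothetical, chosen only for illustration):

#include <cstddef>
#include <cstdint>
#include <iostream>

struct bhalf_stub { std::uint16_t bits; }; // stand-in for ck::bhalf_t, 2 bytes

// mirrors: (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatGemm)
template <typename FloatGemm>
constexpr std::size_t gemm0_lds_bytes(std::size_t a_elems, std::size_t b_elems)
{
    return (a_elems + b_elems) * sizeof(FloatGemm);
}

int main()
{
    constexpr std::size_t a_elems = 128 * 32; // hypothetical A block elements
    constexpr std::size_t b_elems = 128 * 32; // hypothetical B block elements
    // sizing comes from the gemm element type, not the global I/O type
    std::cout << gemm0_lds_bytes<bhalf_stub>(a_elems, b_elems) << " bytes\n"; // 16384 bytes
    return 0;
}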