gaoqiong / composable_kernel · Commit 73e475d8

Commit 73e475d8, authored Aug 07, 2023 by aska-0096

    MQA implementation

parent c3cba9c9

Showing 6 changed files, with 1998 additions and 1 deletion (+1998 / -1):
  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt                                     +1    -0
  example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp        +288  -0
  example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc         +339  -0
  example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp                                                +1    -1
  include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp     +1247 -0
  library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp          +122  -0
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -10,6 +10,7 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
     add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
     add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
     add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
+    add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
 endif()
 add_custom_target(example_gemm_scale_softmax_gemm)
example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                          |-----------------|
                                                                  Gemm0
                                                          |-------------------------------------|
                                                                           Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using Acc0DataType     = F32;
using Acc1DataType     = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

static constexpr auto GemmSpec    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
// #define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
// #define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory = std::tuple<
#ifdef CK_MHA_USE_WAVE_1
    // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 32,
        // Gemm 0
        16, 128, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 8, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 16, 1, 2>, 8, MaskingSpec>,
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 32,
        // Gemm 0
        16, 64, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 4, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 16, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 64,
        // Gemm 0
        32, 128, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 8, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 32, 1, 2>, 8, MaskingSpec>,
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 64,
        // Gemm 0
        32, 64, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 4, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 32, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 128,
        // Gemm 0
        64, 128, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 8, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 64, 1, 2>, 8, MaskingSpec>,
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 128,
        // Gemm 0
        64, 64, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 4, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 64, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 256,
        // Gemm 0
        128, 128, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 8, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 128, 1, 2>, 8, MaskingSpec>,
    ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 256,
        // Gemm 0
        128, 128, 64, 8, 8,
        // Gemm 1
        64, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 8, 4,
        // ABlockTransfer MK -> K0 M K1
        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 128, 1, 2>, 8, MaskingSpec>
#endif
    >;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA<
    ADataType, B0DataType, Acc0DataType, Acc1DataType, AElementOp, B0ElementOp, Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA<
    ADataType, B1DataType, CDataType, Acc1DataType, AElementOp, B1ElementOp, CElementOp>;
#include "run_multi_query_attention_forward_wmma.inc"

int main(int argc, char* argv[]) { return run(argc, argv); }
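For orientation, the fused kernel computes, per (G0, G1) slice, out = Softmax(alpha * Q * K^T) * V, and in MQA the same K/V pair is reused by all G1 query heads of a batch. Below is a minimal host-side sketch in plain float arithmetic; the function name and flat row-major layout are illustrative only and not part of the commit, whose own verification path uses the ReferenceBatchedGemm_MQA and ReferenceSoftmax instances above.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference for one (g0, g1) slice: out = Softmax(alpha * Q * K^T) * V
// Q is M x K, K is N x K (row-major, so Q * K^T contracts over K), V is N x O.
void mqa_forward_reference(const std::vector<float>& q,
                           const std::vector<float>& k,
                           const std::vector<float>& v,
                           std::vector<float>& out,
                           std::size_t M, std::size_t N, std::size_t K, std::size_t O,
                           float alpha)
{
    std::vector<float> p(N); // one row of scaled attention logits
    for(std::size_t m = 0; m < M; ++m)
    {
        // Gemm0 + Scale: p[n] = alpha * dot(Q[m, :], K[n, :])
        float row_max = -std::numeric_limits<float>::infinity();
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t kk = 0; kk < K; ++kk)
                acc += q[m * K + kk] * k[n * K + kk];
            p[n]    = alpha * acc;
            row_max = std::max(row_max, p[n]);
        }
        // numerically stable softmax over the N (sequence) dimension
        float row_sum = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            p[n] = std::exp(p[n] - row_max);
            row_sum += p[n];
        }
        // Gemm1: out[m, o] = sum_n softmax(p)[n] * V[n, o]
        for(std::size_t o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(std::size_t n = 0; n < N; ++n)
                acc += (p[n] / row_sum) * v[n * O + o];
            out[m * O + o] = acc;
        }
    }
}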
example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 120;
    ck::index_t N = 1000;
    ck::index_t K = 64;
    ck::index_t O = 128;

    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
    ck::index_t G0      = 7;
    ck::index_t G1      = 13;
    ck::index_t KV_head = 1;

    float alpha = 1;

    bool input_permute  = false;
    bool output_permute = true;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 13)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);

        input_permute  = std::stoi(argv[11]);
        output_permute = std::stoi(argv[12]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        printf("arg11 to 12: input / output permute\n");
        exit(0);
    }
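    // Example invocation reproducing the default problem size (the executable name comes from the
    // add_example_executable() entry added to CMakeLists.txt in this commit):
    //   ./example_multi_query_attention_forward_wmma_fp16 1 1 1 120 1000 64 128 7 13 1.0 0 1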
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, KV_head, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{N * KV_head * K, K, KV_head * K, 1} // B0 layout [G0, N, G1, K]
            : std::vector<ck::index_t>{KV_head * N * K, N * K, K, 1};      // B0 layout [G0, G1, N, K]

    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, KV_head, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{N * KV_head * O, O, 1, KV_head * O} // B1 layout [G0, N, G1, O]
            : std::vector<ck::index_t>{KV_head * N * O, N * O, 1, O};      // B1 layout [G0, G1, N, O]

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides =
        output_permute
            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
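    // Layout note: lengths are always ordered (G0, G1 or KV_head, rows, cols); the permute flags
    // only change the strides. E.g. with input_permute, A element (g0, g1, m, k) sits at flat
    // offset g0*(M*G1*K) + g1*K + m*(G1*K) + k, i.e. physical order [G0, M, G1, K]; without it,
    // at g0*(G1*M*K) + g1*(M*K) + m*K + k, i.e. [G0, G1, M, K].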
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    case 4:
        // A, B0, B1 all ones
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 5:
        // Rand: b0 b1; unit: a
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 6:
        // Rand: a b0; unit: b1
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 7:
        // Rand: a b1; unit: b0
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 8:
        // Rand: a; unit: b0 b1
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 9:
        // Rand: b0; unit: a b1
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 10:
        // Rand: b1; unit: a b0
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }
    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};
    // do GEMM
    float best_perf         = .0;
    float best_time         = .0;
    int not_pass            = 0;
    std::string best_kernel = "";
    printf("Verification: %s\n", do_verification ? "ON" : "OFF");
    // TODO ANT: replace array with vector?
    ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;

        auto gemm     = DeviceMHAInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
                                          static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          O,
                                          G0,
                                          G1,
                                          alpha,
                                          input_permute,
                                          output_permute);

        if(!gemm.IsSupportedArgument(argument))
        {
            std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
            // return 0;
        }

        float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
        std::size_t num_btype =
            (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 +
            (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0;

        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;
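        // Unit sanity check: flop / 1e9 per millisecond is TFlop/s, and bytes / 1e6 per
        // millisecond is GB/s. For the default problem (M=120, N=1000, K=64, O=128, G0=7, G1=13):
        //   flop = (120*1000*64*2 + 120*1000*128*2) * 7 * 13 ≈ 4.19e9,
        // so a 1 ms kernel run would report ~4.19 TFlops.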
        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;

        if(tflops > best_perf)
        {
            best_perf   = tflops;
            best_time   = ave_time * 1000;
            best_kernel = gemm.GetTypeString();
        }

        if(do_verification)
        {
            c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

            Tensor<ADataType> a_g0_g1_m_k({G0, G1, M, K});
            Tensor<B0DataType> b0_g0_1_k_n({G0, 1, K, N});
            Tensor<B1DataType> b1_g0_1_n_o({G0, 1, N, O});
            Tensor<Acc0DataType> acc0_g0_g1_m_n({G0, G1, M, N});        // scratch object after gemm0
            Tensor<ADataType> a1_g0_g1_m_n({G0, G1, M, N});             // scratch object after softmax
            Tensor<CDataType> c_g0_g1_m_o_host_result({G0, G1, M, O});  // scratch object after gemm1

            // permute
            a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
                a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx);
            });
            b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
                b0_g0_1_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx);
            });
            b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
                b1_g0_1_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx);
            });

            // gemm 0
            auto ref_gemm0          = ReferenceGemm0Instance{};
            auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
            auto ref_gemm0_argument = ref_gemm0.MakeArgument(
                a_g0_g1_m_k, b0_g0_1_k_n, acc0_g0_g1_m_n, a_element_op, b0_element_op, acc0_element_op);

            ref_gemm0_invoker.Run(ref_gemm0_argument);

            // masking
            const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
            acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) {
                if(mask.IsMaskedElement(idx[2], idx[3]))
                    self(idx) = -ck::NumericLimits<float>::Infinity();
            });

            // softmax
            auto ref_softmax          = ReferenceSoftmaxInstance{};
            auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
            auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3});

            ref_softmax_invoker.Run(ref_softmax_argument);

            // gemm1
            auto ref_gemm1          = ReferenceGemm1Instance{};
            auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
            auto ref_gemm1_argument = ref_gemm1.MakeArgument(
                a1_g0_g1_m_n, b1_g0_1_n_o, c_g0_g1_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

            ref_gemm1_invoker.Run(ref_gemm1_argument);

            // permute
            c_gs_ms_os_host_result.ForEach(
                [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); });

            // default absolute error and relative error is 0.001
            double rtol = 1e-3;
            double atol = 1e-3;

            // when BF16 is taken, set absolute error and relative error to 0.01
            if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
               std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
            {
                rtol = 1e-2;
                atol = 1e-2;
            }

            bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                                              c_gs_ms_os_host_result.mData,
                                                              "Error: Incorrect results!",
                                                              rtol,
                                                              atol);
            printf("Verification: %s, Pass: %s\n",
                   do_verification ? "ON" : "OFF",
                   this_run_verification ? "YES" : "NO");
            if(!this_run_verification)
            {
                not_pass = 1;
                printf("%d th MQA instance verification Failed\n", i.value);
            }
        }
    });
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
              << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
              << " us" << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    return not_pass;
}
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp

@@ -14,7 +14,7 @@
 // The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType
 // The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType
-//TODO: Current implementation consume more VGPR than expected.
+// TODO: Current implementation consume more VGPR than expected.
 using ADataType     = ck::half_t;
 using QuantDataType = int8_t;
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

// Multi-Query Attention (MQA) kernel implementation
// Assume number of head of K, V is 1.
// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N]
// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O]
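// Equivalently: every query head g1 of batch g0 attends against the same K[g0, 0, :, :] and
// V[g0, 0, :, :] slice; the single K/V head is broadcast across all G1 query heads (see the
// g_idx / G1 batch-offset mapping in the kernel below).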
template <typename DeviceOp,
          typename GridwiseOp,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
                                          const B0DataType* __restrict__ p_b0_grid,
                                          const B1DataType* __restrict__ p_b1_grid,
                                          CDataType* __restrict__ p_c_grid,
                                          index_t M,  // SequenceQ
                                          index_t N,  // SequenceK
                                          index_t K,  // HeadDim
                                          index_t O,  // HeadDim of V
                                          index_t G0, // Batch
                                          index_t G1, // HeadNum
                                          float alpha,
                                          bool input_permute,
                                          bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
    defined(__gfx1102__))
    // clang-format off
// ***************************************************
    const auto q_head  = G1;
    const auto kv_head = 1;

    // Make Tensor Descriptors
    constexpr index_t array_size = 4;
    std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, q_head, M, K};
    std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K]
            : std::array<ck::index_t, array_size>{q_head * M * K, M * K, K, 1};     // A layout [G0, G1, M, K]

    std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, kv_head, N, K};
    std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K]
            : std::array<ck::index_t, array_size>{kv_head * N * K, N * K, K, 1};      // B0 layout [G0, 1, N, K]

    std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, kv_head, O, N};
    std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O]
            : std::array<ck::index_t, array_size>{kv_head * N * O, N * O, 1, O};      // B1 layout [G0, 1, N, O]

    std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, q_head, M, O};
    std::array<ck::index_t, array_size> c_gs_ms_os_strides =
        output_permute
            ? std::array<ck::index_t, array_size>{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O]
            : std::array<ck::index_t, array_size>{q_head * M * O, M * O, O, 1};     // C layout [G0, G1, M, O]

    const auto a_element_op    = AElementwiseOperation{};
    const auto b0_element_op   = B0ElementwiseOperation{};
    const auto acc0_element_op = AccElementwiseOperation{alpha};
    const auto b1_element_op   = B1ElementwiseOperation{};
    const auto c_element_op    = CElementwiseOperation{};
    // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
    const auto a_grid_desc  = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    const auto b0_grid_desc = DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    const auto b1_grid_desc = DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    const auto c_grid_desc_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
    const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
    const auto a_grid_desc_g_m_k =
        DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    const auto b0_grid_desc_g_l_k =
        DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    const auto b1_grid_desc_g_n_l =
        DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    const auto c_grid_desc_g_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto compute_base_ptr_of_batch = typename DeviceOp::ComputeBasePtrOfStridedBatch{
        a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
    index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
    const auto c0_matrix_mask =
        typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
    // clang-format on

    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
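    // A and C step once per (batch, head) pair: g_idx enumerates all G0 * G1 slices. B0 (K) and
    // B1 (V) carry a single head, so their batch index is g_idx / G1, which makes all G1 query
    // heads of one batch read from the same K/V base pointer.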
    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b0_grid + b0_batch_offset,
                                                p_b1_grid + b1_batch_offset,
                                                p_c_grid + c_batch_offset,
                                                p_shared,
                                                a_grid_desc,
                                                b0_grid_desc,
                                                b1_grid_desc,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b0_element_op,
                                                acc0_element_op,
                                                b1_element_op,
                                                c_element_op,
                                                c0_matrix_mask,
                                                block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b0_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = M;
    ignore = N;
    ignore = K;
    ignore = O;
    ignore = G0;
    ignore = G1;
    ignore = input_permute;
    ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
}
// Computes C = A * B0 * B1
//         MN = MK * KL * LN
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimL,
          index_t NumDimK,
          index_t NumDimN,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc0DataType,
          typename Acc1BiasDataType,
          typename Acc1DataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization B0Spec,
          TensorSpecialization B1Spec,
          TensorSpecialization CSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t LPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t NPerBlock,
          ck::index_t LTilePerBlock,
          ck::index_t L1,
          ck::index_t MPerWmma,
          ck::index_t LPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t LRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename B0BlockTransferThreadClusterLengths_K0_L_K1,
          typename B0BlockTransferThreadClusterArrangeOrder,
          typename B0BlockTransferSrcAccessOrder,
          ck::index_t B0BlockTransferSrcVectorDim,
          ck::index_t B0BlockTransferSrcScalarPerVector,
          ck::index_t B0BlockTransferDstScalarPerVector_K1,
          bool B0BlockLdsAddExtraL,
          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          ck::index_t B1BlockTransferSrcVectorDim,
          ck::index_t B1BlockTransferSrcScalarPerVector,
          ck::index_t B1BlockTransferDstScalarPerVector_L1,
          bool B1BlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          MaskingSpecialization MaskingSpec,
          ck::LoopScheduler LoopSched     = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceMultiQueryAttentionForward_Wmma
    : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                 NumDimM,
                                                 NumDimL,
                                                 NumDimK,
                                                 NumDimN,
                                                 ADataType,
                                                 B0DataType,
                                                 B1DataType,
                                                 CDataType,
                                                 Acc0BiasDataType,
                                                 Acc1BiasDataType,
                                                 AElementwiseOperation,
                                                 B0ElementwiseOperation,
                                                 AccElementwiseOperation,
                                                 B1ElementwiseOperation,
                                                 CElementwiseOperation,
                                                 MaskingSpec>
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
                  "Number of dimensions must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");

    static constexpr index_t NumDimGemm0M = NumDimM;
    static constexpr index_t NumDimGemm0N = NumDimL;
    static constexpr index_t NumDimGemm0K = NumDimK;
    static constexpr index_t NumDimGemm1M = NumDimM;
    static constexpr index_t NumDimGemm1N = NumDimN;
    static constexpr index_t NumDimGemm1K = NumDimL;

    using DeviceOp = DeviceMultiQueryAttentionForward_Wmma;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    static constexpr auto WmmaK = 16;

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);

    static constexpr auto AEnableLds_auto  = LWaves == 1 ? false : true;
    static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
    static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;

    static constexpr auto AEnableLds_manu  = false;
    static constexpr auto B0EnableLds_manu = true;
    static constexpr auto B1EnableLds_manu = true;

    static constexpr auto AEnableLds  = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
    static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);
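    // The *_auto flags request LDS staging whenever more than one wave shares the tensor (LWaves
    // for A, MWaves for B0/B1); the *_manu flags are manual overrides; and NumPrefetch > 1 always
    // forces LDS. The effective setting is the OR of the three.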
    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
        Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
        GemmSpec,
        ASpec,
        B0Spec,
        B1Spec,
        CSpec>;
    __host__ __device__ static auto MakeAGridDescriptor(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        if constexpr(AEnableLds)
        {
            return Transform::MakeAGridDescriptor_AK0_M_AK1(
                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
                Number<AK1>{});
        }
        else
        {
            return Transform::
                MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
                    Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
                                                       a_gs_ms_ks_strides_vec),
                    Number<WmmaK>{},
                    Number<MRepeat>{},
                    Number<MWaves>{},
                    Number<MPerWmma>{},
                    Number<AK1>{});
        }
    }

    __host__ __device__ static auto MakeB0GridDescriptor(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
    {
        if constexpr(B0EnableLds)
        {
            return Transform::MakeB0GridDescriptor_BK0_N_BK1(
                Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                    b0_gs_ls_ks_strides_vec),
                Number<BK1>{});
        }
        else
        {
            return Transform::
                MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
                    Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                        b0_gs_ls_ks_strides_vec),
                    Number<WmmaK>{},
                    Number<LRepeat>{},
                    Number<LWaves>{},
                    Number<LPerWmma>{},
                    Number<BK1>{});
        }
    }

    __host__ __device__ static auto MakeB1GridDescriptor(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
    {
        if constexpr(B1EnableLds)
        {
            return Transform::MakeB1GridDescriptor_BK0_N_BK1(
                Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                    b1_gs_ns_ls_strides_vec),
                Number<L1>{});
        }
        else
        {
            return Transform::
                MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
                    Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                        b1_gs_ns_ls_strides_vec),
                    Number<WmmaK>{},
                    Number<NRepeat>{},
                    Number<NWaves>{},
                    Number<NPerWmma>{},
                    Number<L1>{});
        }
    }
    using AGridDesc        = decltype(MakeAGridDescriptor({}, {}));
    using B0GridDesc       = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc       = decltype(MakeB1GridDescriptor({}, {}));
    using CGridDesc_M_N    = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K  = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    __host__ __device__ constexpr static auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return MaskDisabledPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
        {
            return MaskOutUpperTrianglePredicate{};
        }
    }
    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
    struct ComputeBasePtrOfStridedBatch
    {
        __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                                         const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
                                                         const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
                                                         const CGridDesc_G_M_N& c_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
              b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
        {
            return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
        B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
    };
    // GridwiseOp
    using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma<
        // DataType Family
        ADataType,
        B0DataType,
        Acc0DataType,
        B1DataType,
        Acc1DataType,
        CShuffleDataType,
        CDataType,
        // ElementwiseOp Family
        AElementwiseOperation,
        B0ElementwiseOperation,
        AccElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // InMemory Data Descriptor
        AGridDesc,
        B0GridDesc,
        B1GridDesc,
        CGridDesc_M_N,
        // Tiling Family
        MPerBlock,
        LPerBlock,
        KPerBlock,
        AK1,
        BK1,
        NPerBlock,
        LTilePerBlock,
        L1,
        MPerWmma,
        LPerWmma,
        NPerWmma,
        MRepeat,
        LRepeat,
        NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        true,
        AEnableLds,
        ABlockLdsAddExtraM,
        B0BlockTransferThreadClusterLengths_K0_L_K1,
        B0BlockTransferThreadClusterArrangeOrder,
        B0BlockTransferSrcAccessOrder,
        B0BlockTransferSrcVectorDim,
        B0BlockTransferSrcScalarPerVector,
        B0BlockTransferDstScalarPerVector_K1,
        true,
        B0EnableLds,
        B0BlockLdsAddExtraL,
        B1BlockTransferThreadClusterLengths_L0_N_L1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
        B1BlockTransferSrcVectorDim,
        B1BlockTransferSrcScalarPerVector,
        B1BlockTransferDstScalarPerVector_L1,
        false,
        B1EnableLds,
        B1BlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        Transform::matrix_padder.PadN,
        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
        NumPrefetch,
        LoopSched,
        PipelineVer>;
    struct RawArg : public BaseArgument
    {
        RawArg(const ADataType* p_a_grid,
               const B0DataType* p_b0_grid,
               const B1DataType* p_b1_grid,
               CDataType* p_c_grid,
               index_t M,
               index_t N,
               index_t K,
               index_t O,
               index_t G0,
               index_t G1,
               float alpha,
               bool input_permute,
               bool output_permute)
            : p_a_grid_{p_a_grid},
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              M_{M},
              N_{N},
              K_{K},
              O_{O},
              G0_{G0},
              G1_{G1},
              alpha_{alpha},
              input_permute_{input_permute},
              output_permute_{output_permute}
        {
        }
        // Pointers
        const ADataType* p_a_grid_;
        const B0DataType* p_b0_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;
        // Raw Problem Size
        index_t M_;
        index_t N_;
        index_t K_;
        index_t O_;
        index_t G0_;
        index_t G1_;
        float alpha_;
        bool input_permute_;
        bool output_permute_;
    };
    static auto MakeArgument(const ADataType* p_a,
                             const B0DataType* p_b0,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t O,
                             index_t G0,
                             index_t G1,
                             float alpha,
                             bool input_permute,
                             bool output_permute)
    {
        return RawArg{
            p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
    }
    static bool IsSupportedArgument(const RawArg& arg)
    {
        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
           ck::get_device_name() == "gfx1102")
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
                printf("DeviceOp: Acc0 Type err");
                return false;
            }

            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
                printf("DeviceOp: Acc1 Type err");
                return false;
            }
        }
        else
        {
            printf("DeviceOp: Arch err");
            return false;
        }

        constexpr index_t array_size = 4;
        ck::index_t G0      = arg.G0_;
        ck::index_t G1      = arg.G1_;
        ck::index_t M       = arg.M_;
        ck::index_t N       = arg.N_;
        ck::index_t K       = arg.K_;
        ck::index_t O       = arg.O_;
        bool input_permute  = arg.input_permute_;
        bool output_permute = arg.output_permute_;

        std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
            input_permute
                ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
                : std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

        std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
            input_permute
                ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
                : std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

        std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
            input_permute
                ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
                : std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

        std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
        std::array<ck::index_t, array_size> c_gs_ms_os_strides =
            output_permute
                ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                : std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

        const auto a_grid_desc =
            DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
        const auto b0_grid_desc =
            DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
        const auto b1_grid_desc =
            DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
        const auto c_grid_desc_m_n =
            DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
        const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
        const auto c_grid_desc_g_m_n =
            DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
        index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});

        if(!GridwiseOp::CheckValidity(
               a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
        {
            return false;
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
        if(!(c_g == batch_count))
        {
            printf("DeviceOp: BatchCount err");
            return false;
        }

        // Note: we need raw lengths since threadwise copy can not handle vector load when part of
        // vector is out of bounds
        // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
        const auto MzRaw = M;
        const auto LzRaw = N;
        const auto KzRaw = K;
        const auto NzRaw = O;

        // Check scalar per vector requirement
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
        const auto c_extent_lowest  = NzRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            printf("DeviceOp: Data Transfer Vector scalar err");
            return false;
        }

        std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
            a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
            a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
        std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
            b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
            b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
        std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
            b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
            b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
        std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
            c_gs_ms_os_strides[NumDimG + NumDimM - 1],
            c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};

        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
        const auto b0_stride_lowest =
            B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
        const auto c_stride_lowest = c_mz_nz_strides_[1];

        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            printf("DeviceOp: Data Vectorize transfer err");
            return false;
        }

        return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
    }
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(
            const ADataType* p_a_grid,
            const B0DataType* p_b0_grid,
            const B1DataType* p_b1_grid,
            CDataType* p_c_grid,
            const std::array<void*, NumAcc0Bias> p_acc0_biases,
            const std::array<void*, NumAcc1Bias> p_acc1_biases,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
            const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
            const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
            const index_t M01,
            const index_t N01,
            AElementwiseOperation a_element_op,
            B0ElementwiseOperation b0_element_op,
            AccElementwiseOperation acc_element_op,
            B1ElementwiseOperation b1_element_op,
            CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc{
                  DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc{
                  DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
                  Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc_g_l_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_g_n_l_{
                  Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_g_m_n_{
                  Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
              a_element_op_{a_element_op},
              b0_element_op_{b0_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)},
              raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                       b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
                                       b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
                                       b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
              b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
                                b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
              b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
                                b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
              c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
                               c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_ptr_offset_of_batch_{
                  a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ls_lengths;
            ignore = acc0_biases_gs_ms_ls_strides;
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

            if(GridwiseOp::CheckValidity(
                   a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // Pointers
        const ADataType* p_a_grid_;
        const B0DataType* p_b0_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;

        // Tensor Descriptors
        AGridDesc a_grid_desc;
        B0GridDesc b0_grid_desc;
        B1GridDesc b1_grid_desc;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
        B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;

        // Block to Tile mapping
        typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;

        // ElementwiseOp
        AElementwiseOperation a_element_op_;
        B0ElementwiseOperation b0_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // Strides for the last M/N/K dimensions of A/B0/B1/C
        // for sanity check of vector load/store
        std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
        std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
        std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;

        index_t batch_count_;

        // Batch Offset
        ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_;
    };
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::RawArg;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
            const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);

            const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
            const auto K            = arg.K_;
            // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_multi_query_attention_wmma<DeviceOp,
                                                                      GridwiseOp,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      AElementwiseOperation,
                                                                      B0ElementwiseOperation,
                                                                      AccElementwiseOperation,
                                                                      B1ElementwiseOperation,
                                                                      CElementwiseOperation,
                                                                      has_main_k_block_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b0_grid_,
                                              arg.p_b1_grid_,
                                              arg.p_c_grid_,
                                              arg.M_,
                                              arg.N_,
                                              arg.K_,
                                              arg.O_,
                                              arg.G0_,
                                              arg.G1_,
                                              arg.alpha_,
                                              arg.input_permute_,
                                              arg.output_permute_);
            };

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
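    // Launch-geometry note (illustrative, not part of the implementation): Run()
    // tiles the output into M0 = ceil(M / MPerBlock) by N0 = ceil(O / NPerBlock)
    // tiles and launches one workgroup per tile per (G0, G1) batch entry. With
    // hypothetical sizes M = 512, O = 128, MPerBlock = 128, NPerBlock = 128,
    // G0 = 2, G1 = 8: grid_size = 2 * 8 * 4 * 1 = 64 workgroups of BlockSize
    // threads each.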
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
printf("DeviceOp: Acc0 Type err");
return false;
}
if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
printf("DeviceOp: Acc1 Type err");
return false;
}
}
else
{
printf("DeviceOp: Arch err");
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
if(!(c_g == arg.batch_count_))
{
printf("DeviceOp: BatchCount err");
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
printf("DeviceOp: Data Transfer Vector scalar err");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
printf("DeviceOp: Data Vectorize transfer err");
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
const B0DataType* p_b0,
const B1DataType* p_b1,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b0,
p_b1,
p_c,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ls_ks_lengths,
b0_gs_ls_ks_strides,
b1_gs_ns_ls_lengths,
b1_gs_ns_ls_strides,
c_gs_ms_ns_lengths,
c_gs_ms_ns_strides,
acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op};
}
#endif
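    // Illustrative reading of the (currently disabled) check above, with a
    // hypothetical ABlockTransferSrcScalarPerVector = 8: an innermost K extent
    // of 40 passes the divisibility test (40 % 8 == 0), while 44 does not and
    // the argument would be rejected, because the threadwise copy cannot
    // vector-load past the tensor boundary.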
    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b0,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b0_gs_ls_ks_lengths,
        const std::vector<index_t>& b0_gs_ls_ks_strides,
        const std::vector<index_t>& b1_gs_ns_ls_lengths,
        const std::vector<index_t>& b1_gs_ns_ls_strides,
        const std::vector<index_t>& c_gs_ms_ns_lengths,
        const std::vector<index_t>& c_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
        B0ElementwiseOperation b0_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) override
    {
        std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
        std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
        std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
        std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
        std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;

        std::transform(a_gs_ms_ks_lengths.begin(),
                       a_gs_ms_ks_lengths.end(),
                       a_lengths.begin(),
                       [](index_t i) { return i; });
        std::transform(a_gs_ms_ks_strides.begin(),
                       a_gs_ms_ks_strides.end(),
                       a_strides.begin(),
                       [](index_t i) { return i; });
        std::transform(b0_gs_ls_ks_lengths.begin(),
                       b0_gs_ls_ks_lengths.end(),
                       b0_lengths.begin(),
                       [](index_t i) { return i; });
        std::transform(b0_gs_ls_ks_strides.begin(),
                       b0_gs_ls_ks_strides.end(),
                       b0_strides.begin(),
                       [](index_t i) { return i; });
        std::transform(b1_gs_ns_ls_lengths.begin(),
                       b1_gs_ns_ls_lengths.end(),
                       b1_lengths.begin(),
                       [](index_t i) { return i; });
        std::transform(b1_gs_ns_ls_strides.begin(),
                       b1_gs_ns_ls_strides.end(),
                       b1_strides.begin(),
                       [](index_t i) { return i; });
        std::transform(c_gs_ms_ns_lengths.begin(),
                       c_gs_ms_ns_lengths.end(),
                       c_lengths.begin(),
                       [](index_t i) { return i; });
        std::transform(c_gs_ms_ns_strides.begin(),
                       c_gs_ms_ns_strides.end(),
                       c_strides.begin(),
                       [](index_t i) { return i; });

        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const B0DataType*>(p_b0),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
                                          p_acc0_biases,
                                          p_acc1_biases,
                                          a_lengths,
                                          a_strides,
                                          b0_lengths,
                                          b0_strides,
                                          b1_lengths,
                                          b1_strides,
                                          c_lengths,
                                          c_strides,
                                          acc0_biases_gs_ms_ls_lengths,
                                          acc0_biases_gs_ms_ls_strides,
                                          acc1_biases_gs_ms_ns_lengths,
                                          acc1_biases_gs_ms_ns_strides,
                                          1,
                                          1,
                                          a_element_op,
                                          b0_element_op,
                                          acc_element_op,
                                          b1_element_op,
                                          c_element_op);
    }
    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceMultiQueryAttentionForward_Wmma"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << LPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << LTilePerBlock << ", "
            << L1 << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">"
            << " AEnableLds: " << AEnableLds << ", "
            << "B0EnableLds: " << B0EnableLds << ", "
            << "B1EnableLds: " << B1EnableLds << ", "
            << "NumPrefetch: " << NumPrefetch << ", "
            << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
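For context, a call site follows the usual composable_kernel device-op pattern: build an argument through the polymorphic MakeArgumentPointer shown above, check IsSupportedArgument, then launch through an invoker. A minimal sketch, assuming a concrete DeviceOpInstance typedef, prepared lengths/strides vectors, DeviceMem buffers, and element-op instances (all placeholders; see the example .cpp/.inc files added in this commit for the real setup):

// Sketch only: every name below except the member functions is an assumption.
auto op       = DeviceOpInstance{};
auto argument = op.MakeArgumentPointer(a_buf.GetDeviceBuffer(),
                                       b0_buf.GetDeviceBuffer(),
                                       b1_buf.GetDeviceBuffer(),
                                       c_buf.GetDeviceBuffer(),
                                       {}, // p_acc0_biases (unused, NumAcc0Bias == 0)
                                       {}, // p_acc1_biases (unused, NumAcc1Bias == 0)
                                       a_lengths, a_strides,
                                       b0_lengths, b0_strides,
                                       b1_lengths, b1_strides,
                                       c_lengths, c_strides,
                                       {}, {}, {}, {}, // bias lengths/strides (unused)
                                       a_op, b0_op, acc_op, b1_op, c_op);

if(!op.IsSupportedArgument(argument.get()))
{
    std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
    return;
}

auto invoker   = op.MakeInvokerPointer();
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, true});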
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @ 73e475d8
...
@@ -133,6 +133,128 @@ struct ReferenceBatchedGemm : public device::BaseOperator
    }
};
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceBatchedGemm_MQA : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_g0_g1_m_k,
                 const Tensor<BDataType>& b_g0_1_k_n,
                 Tensor<CDataType>& c_g0_g1_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_g0_g1_m_k_{a_g0_g1_m_k},
              b_g0_1_k_n_{b_g0_1_k_n},
              c_g0_g1_m_n_{c_g0_g1_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_g0_g1_m_k_;
        const Tensor<BDataType>& b_g0_1_k_n_;
        Tensor<CDataType>& c_g0_g1_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceBatchedGemm_MQA::Argument;

        float Run(const Argument& arg)
        {
            auto f_g0g1mk_g01kn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) {
                const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3];

                AccDataType v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    ADataType v_a;
                    BDataType v_b;

                    arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k));
                    arg.b_element_op_(v_b, arg.b_g0_1_k_n_(g0, 0, k, n));

                    v_acc +=
                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                }

                AccDataType v_c;

                arg.c_element_op_(v_c, v_acc);

                arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert<CDataType>(v_c);
            };

            make_ParallelTensorFunctor(f_g0g1mk_g01kn_g0g1mn,
                                       arg.c_g0_g1_m_n_.mDesc.GetLengths()[0],
                                       arg.c_g0_g1_m_n_.mDesc.GetLengths()[1],
                                       arg.c_g0_g1_m_n_.mDesc.GetLengths()[2],
                                       arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_g0_g1_m_k,
                             const Tensor<BDataType>& b_g0_1_k_n,
                             Tensor<CDataType>& c_g0_g1_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{
            a_g0_g1_m_k, b_g0_1_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceBatchedGemm_MQA"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
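Note the defining multi-query detail in the reference kernel above: B is indexed as b_g0_1_k_n(g0, 0, k, n), so the K/V head axis has extent 1 and all G1 query heads read the same shared key head. A host-side verification sketch for the first GEMM of the fused op; the data types, element ops, and tensor names are placeholders modeled on the other attention examples, not definitions from this header:

// Sketch only: a is [G0, G1, M, K] queries, b0 is [G0, 1, K, N] (one shared
// K head), acc0 is [G0, G1, M, N]; all assumed to be pre-filled Tensor<> objects.
using RefGemm0 = ck::tensor_operation::host::ReferenceBatchedGemm_MQA<ADataType,
                                                                      B0DataType,
                                                                      AccDataType,
                                                                      AccDataType,
                                                                      AElementOp,
                                                                      B0ElementOp,
                                                                      AccElementOp>;

auto ref_gemm0    = RefGemm0{};
auto ref_invoker  = ref_gemm0.MakeInvoker();
auto ref_argument = ref_gemm0.MakeArgument(
    a_g0_g1_m_k, b0_g0_1_k_n, acc0_g0_g1_m_n, a_element_op, b0_element_op, acc_element_op);
ref_invoker.Run(ref_argument);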