Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d9579dc8
You need to sign in or sign up before continuing.
Commit
d9579dc8
authored
Mar 07, 2023
by
fsx950223
Browse files
merge updates
parents
98ccee74
36ca02f3
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
386 additions
and
492 deletions
+386
-492
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+5
-7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp
...softmax_gemm/batched_multihead_attention_backward_pt1.cpp
+44
-14
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt2.cpp
...softmax_gemm/batched_multihead_attention_backward_pt2.cpp
+27
-5
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp
...cale_softmax_gemm/batched_multihead_attention_forward.cpp
+13
-9
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp
...softmax_gemm/batched_multihead_attention_forward_fp16.cpp
+0
-170
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+37
-11
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp
...cale_softmax_gemm/grouped_multihead_attention_forward.cpp
+14
-10
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
...softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
+0
-171
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
..._softmax_gemm/run_batched_multihead_attention_forward.inc
+2
-3
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+5
-7
include/ck/ck.hpp
include/ck/ck.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
...vice_grouped_multihead_attention_forward_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+43
-41
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+48
-37
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+4
-4
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+135
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
d9579dc8
...
...
@@ -5,14 +5,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable
(
example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_multihead_attention_forward
_fp16
grouped_multihead_attention_forward
_fp16
.cpp
)
add_example_executable
(
example_batched_multihead_attention_forward
_fp16
batched_multihead_attention_forward
_fp16
.cpp
)
add_example_executable
(
example_
group
ed_multihead_attention_
for
ward_
bf16 group
ed_multihead_attention_
for
ward_
bf16
.cpp
)
add_example_executable
(
example_batched_multihead_attention_
for
ward_
bf16
batched_multihead_attention_
for
ward_
bf16
.cpp
)
add_example_executable
(
example_batched_multihead_attention_
backward_fp16
batched_multihead_attention_
backward_fp16
.cpp
)
add_example_executable
(
example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp
)
add_example_executable
(
example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp
)
add_example_executable
(
example_
batch
ed_multihead_attention_
back
ward_
pt1 batch
ed_multihead_attention_
back
ward_
pt1
.cpp
)
add_example_executable
(
example_batched_multihead_attention_
back
ward_
pt2
batched_multihead_attention_
back
ward_
pt2
.cpp
)
add_example_executable
(
example_batched_multihead_attention_
train
batched_multihead_attention_
train
.cpp
)
add_example_executable
(
example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp
)
add_example_executable
(
example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
add_dependencies
(
example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1
_fp16
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp
View file @
d9579dc8
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK
0
#define USING_MASK
1
#define USING_HD32 0
#include <iostream>
...
...
@@ -49,9 +49,10 @@ Kernel outputs:
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
...
...
@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
DataType
=
BF16
;
using
GemmDataType
=
BF16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
...
...
@@ -101,6 +103,7 @@ using DeviceGemmInstance =
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -169,6 +172,7 @@ using DeviceGemmInstance =
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -340,16 +344,21 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
// 512
ck
::
index_t
N
=
512
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G1
=
6
;
// 16
ck
::
index_t
M
=
1536
;
// 512
ck
::
index_t
N
=
1536
;
// 512
#if USING_HD32
ck
::
index_t
K
=
32
;
// K/O<=32
ck
::
index_t
O
=
32
;
#else
ck
::
index_t
K
=
64
;
// 32<K/O<=64
ck
::
index_t
O
=
64
;
#endif
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G1
=
1
;
// 16
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
true
;
//
false;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.2
;
...
...
@@ -386,6 +395,8 @@ int run(int argc, char* argv[])
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
p_drop
=
std
::
stoi
(
argv
[
13
]);
}
else
{
...
...
@@ -398,6 +409,22 @@ int run(int argc, char* argv[])
exit
(
0
);
}
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
std
::
cout
<<
"M: "
<<
M
<<
std
::
endl
;
std
::
cout
<<
"N: "
<<
N
<<
std
::
endl
;
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
std
::
cout
<<
"p_drop: "
<<
p_drop
<<
std
::
endl
;
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
...
@@ -747,9 +774,12 @@ int run(int argc, char* argv[])
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ygrad_g_m_o
(
idx_gmo
)
*
y_g_m_o
(
idx_gmo
);
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
#if PRINT_HOST
{
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_
fp16
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_
pt2
.cpp
View file @
d9579dc8
...
...
@@ -50,9 +50,10 @@ Kernel outputs:
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
...
...
@@ -387,6 +388,8 @@ int run(int argc, char* argv[])
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
p_drop
=
std
::
stoi
(
argv
[
13
]);
}
else
{
...
...
@@ -399,6 +402,22 @@ int run(int argc, char* argv[])
exit
(
0
);
}
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
std
::
cout
<<
"M: "
<<
M
<<
std
::
endl
;
std
::
cout
<<
"N: "
<<
N
<<
std
::
endl
;
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
std
::
cout
<<
"p_drop: "
<<
p_drop
<<
std
::
endl
;
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
...
@@ -748,9 +767,12 @@ int run(int argc, char* argv[])
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ygrad_g_m_o
(
idx_gmo
)
*
y_g_m_o
(
idx_gmo
);
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
#if PRINT_HOST
{
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward
_bf16
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp
View file @
d9579dc8
...
...
@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
BF16
;
using
B0DataType
=
BF16
;
using
B1DataType
=
BF16
;
using
DataType
=
BF16
;
using
GemmDataType
=
BF16
;
using
ADataType
=
DataType
;
using
B0DataType
=
DataType
;
using
B1DataType
=
DataType
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
BF16
;
using
CDataType
=
DataType
;
using
ZDataType
=
U16
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
...
...
@@ -81,6 +84,7 @@ using DeviceGemmInstance =
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -99,7 +103,7 @@ using DeviceGemmInstance =
TensorSpecC
,
1
,
256
,
256
,
// MPerBlock
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// Gemm1NPerBlock
...
...
@@ -109,7 +113,7 @@ using DeviceGemmInstance =
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
2
,
// MXdlPerWave
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
...
...
@@ -139,7 +143,7 @@ using DeviceGemmInstance =
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// Ref Gemm0:
bf16 in, fp32
out
// Ref Gemm0:
DataType in, AccDataType
out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
AccDataType
,
...
...
@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp
,
Acc0ElementOp
>
;
// Ref Softmax:
fp32 in, bf16
out
// Ref Softmax:
AccDataType in, DataType
out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
ADataType
,
AccDataType
>
;
// Ref Gemm1:
bf16 in, bf16
out
// Ref Gemm1:
DataType in, DataType
out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp
deleted
100644 → 0
View file @
98ccee74
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F16
;
using
ZDataType
=
U16
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
256
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
2
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
ADataType
,
AccDataType
>
;
// Ref Gemm1: fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
AccDataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ZDataType
,
ADataType
,
ADataType
>
;
#include "run_batched_multihead_attention_forward.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train
_fp16
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
d9579dc8
...
...
@@ -59,9 +59,10 @@ Kernel outputs:
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
...
...
@@ -69,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
DataType
=
BF16
;
using
GemmDataType
=
BF16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
...
...
@@ -108,6 +110,7 @@ using DeviceGemmInstanceFWD =
DataType
,
DataType
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -180,6 +183,7 @@ using DeviceGemmInstanceBWD =
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -248,6 +252,7 @@ using DeviceGemmInstanceBWD =
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -419,8 +424,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
200
;
// 512
ck
::
index_t
N
=
200
;
// 512
ck
::
index_t
M
=
129
;
// 512
ck
::
index_t
N
=
129
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
G0
=
4
;
// 54
...
...
@@ -428,8 +433,8 @@ int run(int argc, char* argv[])
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
fals
e
;
bool
output_permute
=
fals
e
;
bool
input_permute
=
tru
e
;
bool
output_permute
=
tru
e
;
float
p_drop
=
0.0
;
float
p_dropout
=
1
-
p_drop
;
...
...
@@ -465,6 +470,8 @@ int run(int argc, char* argv[])
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
p_drop
=
std
::
stoi
(
argv
[
13
]);
}
else
{
...
...
@@ -477,6 +484,22 @@ int run(int argc, char* argv[])
exit
(
0
);
}
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
std
::
cout
<<
"M: "
<<
M
<<
std
::
endl
;
std
::
cout
<<
"N: "
<<
N
<<
std
::
endl
;
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
std
::
cout
<<
"p_drop: "
<<
p_drop
<<
std
::
endl
;
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
...
@@ -959,9 +982,12 @@ int run(int argc, char* argv[])
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ygrad_g_m_o
(
idx_gmo
)
*
y_g_m_o
(
idx_gmo
);
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
#if PRINT_HOST
{
...
...
@@ -1058,7 +1084,7 @@ int run(int argc, char* argv[])
double
atol
=
1e-3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
if
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
||
std
::
is_same_v
<
GemmDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1e-2
;
atol
=
1e-2
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward
_bf16
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp
View file @
d9579dc8
...
...
@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
BF16
;
using
B0DataType
=
BF16
;
using
B1DataType
=
BF16
;
using
DataType
=
F16
;
using
GemmDataType
=
F16
;
using
ADataType
=
DataType
;
using
B0DataType
=
DataType
;
using
B1DataType
=
DataType
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
BF16
;
using
CDataType
=
DataType
;
using
ZDataType
=
U16
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
...
...
@@ -81,6 +84,7 @@ using DeviceGemmInstance =
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
...
...
@@ -102,8 +106,8 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
...
...
@@ -111,7 +115,7 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -139,7 +143,7 @@ using DeviceGemmInstance =
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// Ref Gemm0:
bf16 in, fp32
out
// Ref Gemm0:
DataType in, AccDataType
out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
AccDataType
,
...
...
@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp
,
Acc0ElementOp
>
;
// Ref Softmax:
fp32 in, bf16
out
// Ref Softmax:
AccDataType in, DataType
out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
ADataType
,
AccDataType
>
;
// Ref Gemm1:
bf16 in, bf16
out
// Ref Gemm1:
DataType in, DataType
out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
deleted
100644 → 0
View file @
98ccee74
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F16
;
using
ZDataType
=
U16
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
ADataType
,
AccDataType
>
;
// Ref Gemm1: fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
AccDataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ZDataType
,
ADataType
,
ADataType
>
;
#include "run_grouped_multihead_attention_forward.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
View file @
d9579dc8
...
...
@@ -97,7 +97,7 @@ int run(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
out
put_permute
in
put_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
...
...
@@ -360,8 +360,7 @@ int run(int argc, char* argv[])
double
atol
=
1
e
-
3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B0DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B1DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
if
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
||
std
::
is_same_v
<
GemmDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1
e
-
2
;
atol
=
1
e
-
2
;
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
d9579dc8
...
...
@@ -10,8 +10,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
output_permute
=
true
;
float
p_drop
=
0.2
;
float
p_drop
=
0.1
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
...
...
@@ -84,8 +83,8 @@ int run(int argc, char* argv[])
int
M
=
128
*
(
rand
()
%
8
+
1
);
int
N
=
128
*
(
rand
()
%
8
+
1
);
int
K
=
128
;
int
O
=
128
;
int
K
=
64
;
int
O
=
64
;
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
5
+
1
;
...
...
@@ -117,7 +116,7 @@ int run(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
out
put_permute
in
put_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
...
...
@@ -427,8 +426,7 @@ int run(int argc, char* argv[])
double
atol
=
1
e
-
3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B0DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B1DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
if
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
||
std
::
is_same_v
<
GemmDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1
e
-
2
;
atol
=
1
e
-
2
;
...
...
include/ck/ck.hpp
View file @
d9579dc8
...
...
@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
0
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-register sub-dword transpose
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
d9579dc8
...
...
@@ -173,6 +173,7 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
...
...
@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
DataType
,
// TODO: distinguish A/B datatype
LSE
DataType
,
Gemm
DataType
,
GemmAccDataType
,
CShuffleDataType
,
LSEDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d9579dc8
...
...
@@ -158,6 +158,7 @@ template <index_t NumDimG,
typename
BDataType
,
typename
B1DataType
,
typename
CDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
...
...
@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmDataType
,
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d9579dc8
...
...
@@ -148,6 +148,7 @@ template <index_t NumDimG,
typename
BDataType
,
typename
B1DataType
,
typename
CDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
...
...
@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmDataType
,
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
d9579dc8
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
DataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatLSE
,
...
...
@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
Gemm
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
...
...
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
Gemm
DataType
,
GridDesc_K0_M_K1
,
decltype
(
q_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
Gemm
DataType
,
GridDesc_K0_N_K1
,
decltype
(
k_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
Gemm
DataType
,
GridDesc_K0_N_K1
,
decltype
(
v_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
Gemm
DataType
,
GridDesc_K0_M_K1
,
decltype
(
ygrad_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
Gemm
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatGemmAcc
,
DataType
,
Gemm
DataType
,
decltype
(
a_src_thread_desc_k0_m_k1
),
decltype
(
a_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
MfmaSelector
<
Gemm
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0MWaves
;
static
constexpr
index_t
GemmNWave
=
Gemm0NWaves
;
...
...
@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static
constexpr
auto
b_thread_desc_k0_n_k1
=
MakeBThreadDesc_K0_N_K1
();
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
ThreadwiseTensorSliceTransfer_v2
<
Gemm
DataType
,
Gemm
DataType
,
decltype
(
b_block_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
,
...
...
@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
Gemm
DataType
,
FloatGemmAcc
,
decltype
(
a_thread_desc_k0_m_k1
),
decltype
(
b_thread_desc_k0_n_k1
),
...
...
@@ -733,12 +735,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static
constexpr
index_t
GemmORepeat
=
Free1_O
/
GemmOWave
/
NPerXdl
;
static
constexpr
index_t
GemmMLoop
=
Free1_M
/
Sum_M
;
static
constexpr
index_t
GemmMPack
=
math
::
max
(
A_M1
,
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
math
::
max
(
A_M1
,
MfmaSelector
<
Gemm
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_M3
=
GemmMPack
;
// 8
static
constexpr
index_t
B_M2
=
XdlopsGemm
<
DataType
,
MPerXdl
,
NPerXdl
,
GemmMPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_M1
=
Sum_M
/
B_M2
/
B_M3
;
// 4
static
constexpr
index_t
B_M0
=
GemmMLoop
;
// 2
XdlopsGemm
<
Gemm
DataType
,
MPerXdl
,
NPerXdl
,
GemmMPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_M1
=
Sum_M
/
B_M2
/
B_M3
;
// 4
static
constexpr
index_t
B_M0
=
GemmMLoop
;
// 2
__host__
__device__
static
constexpr
auto
GetABlockSliceLengths_M0_N0_M1_N1_M2_N2
()
{
...
...
@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
template
<
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
Gemm
DataType
,
decltype
(
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
ElementwiseOp
,
...
...
@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static
constexpr
auto
b_thread_desc_m0_o_m1
=
MakeBThreadDesc_M0_O_M1
();
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
ThreadwiseTensorSliceTransfer_v2
<
Gemm
DataType
,
Gemm
DataType
,
decltype
(
b_block_desc_o0_o1_o2_m0_m1_m2_m3
),
decltype
(
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
),
BThreadSlice_O0_O1_O2_M0_M1_M2_M3
,
...
...
@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
Gemm
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_m0_n_m1
),
decltype
(
b_thread_desc_m0_o_m1
),
...
...
@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
Gemm2Params_N_O_M
::
GemmMPack
,
true
,
// TransposeC
Gemm2Params_N_O_M
::
GemmMPack
*
XdlopsGemm
<
DataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params_N_O_M
::
GemmMPack
,
false
>
{}
XdlopsGemm
<
Gemm
DataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params_N_O_M
::
GemmMPack
,
false
>
{}
.
K0PerXdlops
,
Gemm2Params_N_O_M
::
GemmMPack
>
;
...
...
@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
FloatGemmAcc
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
...
...
@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static
constexpr
auto
p_slash_sgrad_block_desc_m0_n_m1
=
GetA2BlockDescriptor_M0_N_M1
<
Gemm2Params_N_O_M
>
();
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
DataType
)
>
{};
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
Gemm
DataType
)
>
{};
static
constexpr
auto
q_block_space_size_aligned
=
math
::
integer_least_multiple
(
q_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
...
...
@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static
constexpr
auto
reduction_space_offset
=
(
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
)
*
sizeof
(
DataType
)
/
sizeof
(
FloatGemmAcc
);
sizeof
(
Gemm
DataType
)
/
sizeof
(
FloatGemmAcc
);
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
{
const
index_t
k_bytes_end
=
(
SharedMemTrait
::
k_block_space_offset
+
SharedMemTrait
::
k_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
Gemm
DataType
);
const
index_t
v_bytes_end
=
(
SharedMemTrait
::
v_block_space_offset
+
SharedMemTrait
::
v_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
Gemm
DataType
);
const
index_t
p_slash_sgrad_bytes_end
=
(
SharedMemTrait
::
p_slash_sgrad_block_space_offset
+
SharedMemTrait
::
p_slash_sgrad_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
Gemm
DataType
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
...
...
@@ -1263,8 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const
float
p_drop
,
ck
::
philox
&
ph
)
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
__builtin_amdgcn_readfirstlane
(
std
::
floor
(
p_dropout
*
65535.0
));
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
...
...
@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// LDS allocation for Q / K / V / dY
auto
q_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
q_block_space_offset
,
static_cast
<
Gemm
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
q_block_space_offset
,
GemmBlockwiseCopy
::
q_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
k_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
k_block_space_offset
,
static_cast
<
Gemm
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
k_block_space_offset
,
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
v_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
v_block_space_offset
,
static_cast
<
Gemm
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
v_block_space_offset
,
GemmBlockwiseCopy
::
v_block_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
ygrad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
static_cast
<
Gemm
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
GemmBlockwiseCopy
::
ygrad_block_desc_k0_m_k1
.
GetElementSpaceSize
());
// Q matrix blockwise copy
...
...
@@ -1394,10 +1396,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
decltype
(
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
())
>
;
// Gemm1: VGPR allocation for A and B
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Gemm
DataType
>
(
Gemm1
::
a_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
gemm1_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
auto
gemm1_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Gemm
DataType
>
(
Gemm1
::
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
.
GetElementSpaceSize
());
// dQ: transform input and output tensor descriptors
...
...
@@ -1589,10 +1591,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_slash_sgrad_block_space_offset
,
static_cast
<
Gemm
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_slash_sgrad_block_space_offset
,
Gemm2
::
a_block_desc_m0_n_m1
.
GetElementSpaceSize
());
auto
gemm2_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
auto
gemm2_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Gemm
DataType
>
(
Gemm2
::
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
.
GetElementSpaceSize
());
// dV: transform input and output tensor descriptors
...
...
@@ -1722,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
...
...
@@ -1735,8 +1737,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for ygrad
auto
ygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
Gemm
DataType
,
FloatGemmAcc
,
decltype
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
.
GetLengths
()),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
d9579dc8
...
...
@@ -908,7 +908,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
FloatGemmAcc
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d9579dc8
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
FloatAB
,
typename
FloatGemm
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatC
,
...
...
@@ -126,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
Float
AB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
Float
Gemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
...
...
@@ -242,10 +243,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a_block_space_size_aligned
+
SharedMemTrait
::
b_block_space_size_aligned
)
*
sizeof
(
Float
AB
);
sizeof
(
Float
Gemm
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
Float
AB
);
sizeof
(
Float
Gemm
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
...
...
@@ -273,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
//
if(Gemm1N != K)
//
{
//
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
//
return false;
//
}
if
(
Gemm1N
!=
K
)
{
std
::
cout
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
...
...
@@ -495,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
Float
AB
,
Float
Gemm
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -526,7 +527,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
Float
AB
,
Float
Gemm
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -554,12 +555,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
Float
AB
,
Float
Gemm
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -579,11 +581,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
AB
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
Float
Gemm
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
AB
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
static_cast
<
Float
Gemm
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
...
...
@@ -658,7 +660,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// A1 matrix blockwise copy
auto
a1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatGemmAcc
,
Float
AB
,
Float
Gemm
,
decltype
(
acc_thread_desc_k0_m_k1
),
decltype
(
a1_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -677,7 +679,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
FloatAB
,
Float
AB
,
Float
Gemm
,
decltype
(
b1_grid_desc_bk0_n_bk1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
B1BlockTransferSrcAccessOrder
,
...
...
@@ -698,12 +700,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
AB
>
(
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
Gemm
>
(
a1_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
// reuse LDS space for gemm0's b_block_buf
auto
b1_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
AB
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
static_cast
<
Float
Gemm
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
...
...
@@ -716,11 +718,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr
index_t
Gemm1KPack
=
MfmaSelector
<
Float
AB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
MfmaSelector
<
Float
Gemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
Float
AB
,
Float
Gemm
,
FloatGemmAcc
,
decltype
(
a1_thread_desc_k0_m_k1
),
decltype
(
b1_block_desc_bk0_n_bk1
),
...
...
@@ -736,7 +738,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Gemm1KPack
,
true
,
// TransposeC
Gemm1KPack
,
// AMmaKStride
Gemm1KPack
*
XdlopsGemm
<
Float
AB
,
MPerXdl
,
NPerXdl
,
Gemm1KPack
,
false
>
{}.
K0PerXdlops
>
{
Gemm1KPack
*
XdlopsGemm
<
Float
Gemm
,
MPerXdl
,
NPerXdl
,
Gemm1KPack
,
false
>
{}.
K0PerXdlops
>
{
// BMmaKStride
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
...
...
@@ -850,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -881,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -1004,25 +1006,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global
if
(
p_z_grid
)
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
false
>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
false
,
decltype
(
n0
),
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
make_multi_index
(
0
,
0
,
0
,
-
(
n0
.
value
),
0
,
0
,
0
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
);
...
...
@@ -1100,7 +1111,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// workaround compiler issue; see ck/ck.hpp
if
constexpr
(
CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE
==
1
&&
is_same_v
<
Float
AB
,
bhalf_t
>
&&
MPerBlock
==
256
&&
NPerBlock
==
128
&&
(
is_same_v
<
Float
Gemm
,
bhalf_t
>
)
&&
MPerBlock
==
256
&&
NPerBlock
==
128
&&
Gemm1NPerBlock
==
128
)
{
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
d9579dc8
...
...
@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x
7fffffff
;
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x
80000000
;
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
...
...
@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
7fffffff
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
80000000
;
amd_buffer_store_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
...
...
@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
7fffffff
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
80000000
;
amd_buffer_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
...
...
@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
7fffffff
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x
80000000
;
amd_buffer_atomic_max_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
...
...
include/ck/utility/generic_memory_space_atomic.hpp
View file @
d9579dc8
...
...
@@ -71,6 +71,141 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return
vy
.
template
AsType
<
double2_t
>()[
I0
];
}
// Lane-wise sum of two packed half2_t values.
//
// a, b    operands; each is read at indices 0 and 1
// return  a half2_t whose element i equals a[i] + b[i] for i in {0, 1}
inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
{
    half2_t sum;
    for(int lane = 0; lane < 2; ++lane)
    {
        sum[lane] = a[lane] + b[lane];
    }
    return sum;
}
// Pointer-punning helper: holds one address that can be read back either as a
// uint32_t* or as a half2_t*. atomic_add<half2_t> stores the destination through
// fp162_a and then accesses the same location as a 32-bit word through u32_a.
// NOTE(review): this relies on reading the union member that was not last written,
// which the device compiler is assumed to support - confirm.
union
U32FP162_ADDR
{
// the address viewed as a pointer to a 32-bit word
uint32_t
*
u32_a
;
// the same address viewed as a pointer to a packed half2_t
half2_t
*
fp162_a
;
};
// Value-punning helper: one storage slot viewed either as a raw uint32_t or as a
// half2_t. atomic_add<half2_t> uses it to move between the bit pattern exchanged by
// atomicCAS and the half2_t value handed to add_fp16x2_t.
// NOTE(review): assumes sizeof(half2_t) == sizeof(uint32_t) - confirm.
union
U32FP162
{
// raw bit pattern
uint32_t
u32
;
// the same bits interpreted as a packed half2_t
half2_t
fp162
;
};
// Atomically adds the packed pair `x` to the half2_t stored at `p_dst`.
//
// The destination is handled as a single 32-bit word: the loop takes the word it last
// observed, adds `x` lane-wise with add_fp16x2_t, and tries to publish the sum with
// atomicCAS. If another thread modified the word in the meantime the CAS reports the
// newer word and the loop retries with it.
//
// p_dst   destination; it is dereferenced through a uint32_t*, so it presumably has to
//         be 4-byte aligned - TODO confirm against callers
// x       pair to add
// return  the pair that was stored at `p_dst` immediately before this call's addition
//         took effect (the usual atomicAdd convention), not the addend `x`
template <>
__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
{
    U32FP162_ADDR dst_addr;
    dst_addr.fp162_a = p_dst;

    U32FP162 observed;
    observed.u32 = *dst_addr.u32_a; // first guess only; validated by the CAS below

    uint32_t expected;
    do
    {
        expected = observed.u32;

        U32FP162 updated;
        updated.fp162 = add_fp16x2_t(observed.fp162, x);

        // atomicCAS hands back the word it found at the address; it equals `expected`
        // exactly when the swap was performed.
        observed.u32 = atomicCAS(dst_addr.u32_a, expected, updated.u32);
    } while(observed.u32 != expected);

    // On exit `observed` is the word this call replaced, i.e. the previous value.
    return observed.fp162;
}
// template <>
// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// union U16BF16 {
// uint16_t u16;
// bhalf_t bf16;
// };
// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
// U16BF16 xa {.bf16 = a};
// U16BF16 xb {.bf16 = b};
// U16BF16 xr;
// xr.u16 = xa.u16 + xb.u16;
// return xr.bf16;
// }
// Adds two bhalf_t values by way of float: both operands are converted to float with
// type_convert, summed, and the float sum is converted back to bhalf_t.
//
// a, b    operands
// return  type_convert<bhalf_t> of the float sum
inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);
    return type_convert<bhalf_t>(lhs + rhs);
}
// Lane-wise sum of two packed bhalf2_t values; each lane is added with add_bf16_t.
//
// a, b    operands; each is read at indices 0 and 1
// return  a bhalf2_t whose element i equals add_bf16_t(a[i], b[i]) for i in {0, 1}
inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
{
    bhalf2_t sum;
    for(int lane = 0; lane < 2; ++lane)
    {
        sum[lane] = add_bf16_t(a[lane], b[lane]);
    }
    return sum;
}
// Pointer-punning helper: holds one address that can be read back either as a
// uint32_t* or as a bhalf2_t*. atomic_add<bhalf2_t> stores the destination through
// bf162_a and then accesses the same location as a 32-bit word through u32_a.
// NOTE(review): this relies on reading the union member that was not last written,
// which the device compiler is assumed to support - confirm.
union
U32BF162_ADDR
{
// the address viewed as a pointer to a 32-bit word
uint32_t
*
u32_a
;
// the same address viewed as a pointer to a packed bhalf2_t
bhalf2_t
*
bf162_a
;
};
// Value-punning helper: one storage slot viewed either as a raw uint32_t or as a
// bhalf2_t. atomic_add<bhalf2_t> uses it to move between the bit pattern exchanged by
// atomicCAS and the bhalf2_t value handed to add_bf16x2_t.
// NOTE(review): assumes sizeof(bhalf2_t) == sizeof(uint32_t) - confirm.
union
U32BF162
{
// raw bit pattern
uint32_t
u32
;
// the same bits interpreted as a packed bhalf2_t
bhalf2_t
bf162
;
};
// Atomically adds the packed pair `x` to the bhalf2_t stored at `p_dst`.
//
// The destination is handled as a single 32-bit word: the loop takes the word it last
// observed, adds `x` lane-wise with add_bf16x2_t, and tries to publish the sum with
// atomicCAS. If another thread modified the word in the meantime the CAS reports the
// newer word and the loop retries with it.
//
// p_dst   destination; it is dereferenced through a uint32_t*, so it presumably has to
//         be 4-byte aligned - TODO confirm against callers
// x       pair to add
// return  the pair that was stored at `p_dst` immediately before this call's addition
//         took effect (the usual atomicAdd convention), not the addend `x`
template <>
__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
{
    U32BF162_ADDR dst_addr;
    dst_addr.bf162_a = p_dst;

    U32BF162 observed;
    observed.u32 = *dst_addr.u32_a; // first guess only; validated by the CAS below

    uint32_t expected;
    do
    {
        expected = observed.u32;

        U32BF162 updated;
        updated.bf162 = add_bf16x2_t(observed.bf162, x);

        // atomicCAS hands back the word it found at the address; it equals `expected`
        // exactly when the swap was performed.
        observed.u32 = atomicCAS(dst_addr.u32_a, expected, updated.u32);
    } while(observed.u32 != expected);

    // On exit `observed` is the word this call replaced, i.e. the previous value.
    return observed.bf162;
}
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment