Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
883c060a
Commit
883c060a
authored
Sep 04, 2023
by
guangzlu
Browse files
bwd qloop v1 passed
parent
72498367
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
272 additions
and
501 deletions
+272
-501
client_example/08_fused_attention/fused_attention_bwd.cpp
client_example/08_fused_attention/fused_attention_bwd.cpp
+106
-323
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
...y/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
+46
-50
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...ched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+60
-64
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...atched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+60
-64
No files found.
client_example/08_fused_attention/fused_attention_bwd.cpp
View file @
883c060a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <vector>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp"
...
...
@@ -14,13 +16,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -51,11 +46,9 @@ static constexpr ck::index_t NumDimM = 1;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
;
...
...
@@ -64,12 +57,6 @@ static constexpr auto MaskingSpec =
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
...
...
@@ -86,70 +73,10 @@ struct SimpleDeviceMem
void
*
p_mem_
;
};
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorV
,
typename
TensorS
,
typename
TensorP
,
typename
TensorZ
,
typename
TensorY
,
typename
TensorLSE
=
TensorP
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
const
TensorK
&
k_g_n_k
,
const
TensorV
&
v_g_n_o
,
const
float
alpha
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
,
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
ZDataType
p_dropout_in_16bits
,
float
rp_dropout
)
{
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
q_g_m_k
,
k_g_k_n
,
s_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
auto
M
=
s_g_m_n
.
GetLengths
()[
1
];
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// P = Softmax(S)
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
},
&
lse_g_m
);
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// P_dropped
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// Y = P_dropout * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_drop_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
init_method
=
1
;
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
...
...
@@ -164,53 +91,12 @@ int main(int argc, char* argv[])
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
p_drop
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
ZDataType
p_dropout_in_16bits
=
ZDataType
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
std
::
cout
<<
"M: "
<<
M
<<
std
::
endl
;
std
::
cout
<<
"N: "
<<
N
<<
std
::
endl
;
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
...
...
@@ -264,120 +150,19 @@ int main(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
InputDataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"z_gs_ms_ns: "
<<
z_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
break
;
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
break
;
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
break
;
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break
;
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
InputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
InputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward
<
2
,
SimpleDeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
k_device_buf
(
sizeof
(
InputDataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
G0
*
G1
*
M
*
N
);
SimpleDeviceMem
v_device_buf
(
sizeof
(
InputDataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
G0
*
G1
*
M
*
O
);
SimpleDeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
G0
*
G1
*
M
);
SimpleDeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
G0
*
G1
*
M
*
O
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -388,11 +173,11 @@ int main(int argc, char* argv[])
LSEDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
A
ElementOp
,
B0
ElementOp
,
Acc0ElementOp
,
B1
ElementOp
,
C
ElementOp
,
QKV
ElementOp
,
QKV
ElementOp
,
Scale
,
QKV
ElementOp
,
Y
ElementOp
,
MaskingSpec
>
;
// get device op instances
...
...
@@ -409,13 +194,12 @@ int main(int argc, char* argv[])
// profile device op instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
q_device_buf
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
q_device_buf
.
GetDeviceBuffer
(),
k_device_buf
.
GetDeviceBuffer
(),
nullptr
,
// set to nullptr
v_device_buf
.
GetDeviceBuffer
(),
...
...
@@ -450,13 +234,11 @@ int main(int argc, char* argv[])
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
...
...
@@ -496,7 +278,8 @@ int main(int argc, char* argv[])
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
q_device_buf
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
q_device_buf
.
GetDeviceBuffer
(),
k_device_buf
.
GetDeviceBuffer
(),
nullptr
,
// set to nullptr
v_device_buf
.
GetDeviceBuffer
(),
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
View file @
883c060a
...
...
@@ -18,8 +18,9 @@ namespace device {
namespace
instance
{
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
...
...
@@ -34,12 +35,11 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -55,12 +55,11 @@ void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instance
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -76,12 +75,11 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -97,8 +95,7 @@ void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instan
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
template
<
typename
InputDataType
,
typename
OutputDataType
,
...
...
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
}
else
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
883c060a
...
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
@@ -50,8 +51,7 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimO
,
MaskingSpecialization
MaskingSpec
>
using
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
using
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
...
...
@@ -65,8 +65,8 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
>
;
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -82,11 +82,9 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
2
,
1
,
...
...
@@ -97,8 +95,8 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
}
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -114,11 +112,9 @@ void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instan
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
2
,
1
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
883c060a
...
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
@@ -50,8 +51,7 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimO
,
MaskingSpecialization
MaskingSpec
>
using
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
using
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
...
...
@@ -65,8 +65,8 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
>
;
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -82,11 +82,9 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
2
,
1
,
...
...
@@ -97,8 +95,8 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
}
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
...
...
@@ -114,11 +112,9 @@ void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instance
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
2
,
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment