gaoqiong / composable_kernel · Commits · 883c060a

Commit 883c060a, authored Sep 04, 2023 by guangzlu

bwd qloop v1 passed

parent 72498367
Showing 4 changed files with 272 additions and 501 deletions (+272 −501)
client_example/08_fused_attention/fused_attention_bwd.cpp (+106 −323)
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp (+46 −50)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp (+60 −64)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+60 −64)
client_example/08_fused_attention/fused_attention_bwd.cpp (view file @ 883c060a)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.

#include <iostream>
#include <vector>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp"
@@ -14,13 +16,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -51,11 +46,9 @@ static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
@@ -64,12 +57,6 @@ static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
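MaskUpperTriangleFromTopLeft is the causal-attention case: query row m may attend only to key columns n <= m, and masked scores are forced to minus infinity before the softmax so they carry zero probability. A minimal illustrative predicate (not CK's C0MatrixMask implementation):

    // Causal mask keyed from the top-left corner: hide everything strictly
    // above the main diagonal of the M x N score matrix.
    inline bool is_masked_upper_triangle_from_top_left(ck::index_t m, ck::index_t n)
    {
        return n > m;
    }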
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;
@@ -86,70 +73,10 @@ struct SimpleDeviceMem
    void* p_mem_;
};
-template <typename TensorQ,
-          typename TensorK,
-          typename TensorV,
-          typename TensorS,
-          typename TensorP,
-          typename TensorZ,
-          typename TensorY,
-          typename TensorLSE = TensorP>
-void run_attention_fwd_host(const TensorQ& q_g_m_k,
-                            const TensorK& k_g_n_k,
-                            const TensorV& v_g_n_o,
-                            const float alpha,
-                            TensorS& s_g_m_n,
-                            TensorP& p_g_m_n,
-                            TensorY& y_g_m_o,
-                            TensorLSE& lse_g_m,
-                            TensorP& p_drop_g_m_n,
-                            TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
-                            float rp_dropout)
-{
-    // S = alpha * Q * K^T
-    auto k_g_k_n            = k_g_n_k.Transpose({0, 2, 1});
-    auto ref_gemm0          = ReferenceGemm0Instance{};
-    auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
-    auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-        q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
-
-    ref_gemm0_invoker.Run(ref_gemm0_argument);
-
-    // masking
-    auto M          = s_g_m_n.GetLengths()[1];
-    auto N          = s_g_m_n.GetLengths()[2];
-    const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
-    s_g_m_n.ForEach([&](auto& self, auto idx) {
-        if(mask.IsMaskedElement(idx[1], idx[2]))
-            self(idx) = -ck::NumericLimits<float>::Infinity();
-    });
-
-    // P = Softmax(S)
-    auto ref_softmax          = ReferenceSoftmaxInstance{};
-    auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
-
-    ref_softmax_invoker.Run(ref_softmax_argument);
-
-    // P_dropped
-    auto ref_dropout         = ReferenceDropoutInstance{};
-    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
-    auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
-    ref_dropout_invoker.Run(ref_dropout_argment);
-
-    // Y = P_dropout * V
-    auto ref_gemm1          = ReferenceGemm1Instance{};
-    auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
-    auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
-
-    ref_gemm1_invoker.Run(ref_gemm1_argument);
-}
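In symbols, the host reference above computes the attention forward pass that the backward kernel differentiates:

$$S = \alpha\,Q K^{\top},\qquad P = \operatorname{softmax}(S)\ \text{(row-wise, masked scores at } -\infty),\qquad \mathrm{lse}_m = \log\sum_n e^{S_{mn}},\qquad Y = \operatorname{dropout}(P)\,V.$$

Keeping only the per-row log-sum-exp is what lets a backward pass rebuild P from Q and K on the fly instead of storing the full M-by-N probability matrix.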
int main(int argc, char* argv[])
{
    int init_method = 1;

    ck::index_t M = 512;
    ck::index_t N = 512;
    ck::index_t K = DIM;
@@ -164,53 +91,12 @@ int main(int argc, char* argv[])
    const unsigned long long seed   = 1;
    const unsigned long long offset = 0;
    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 13)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        p_drop = std::stof(argv[10]);

        input_permute  = std::stoi(argv[11]);
        output_permute = std::stoi(argv[12]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M, N, K, O, G0, G1\n");
        printf("arg10: dropout probability (p_drop)\n");
        printf("arg11 to 12: input / output permute\n");
        exit(0);
    }
    float p_dropout               = 1 - p_drop;
    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
    float rp_dropout              = 1.0 / p_dropout;
    float alpha                   = 1.f / std::sqrt(K);
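Here p_dropout is the keep probability, encoded as a 16-bit threshold so it can be compared directly against the uniform 16-bit draws in the Z tensor, and survivors are rescaled by rp_dropout = 1/p_dropout so the tensor keeps its expected value (inverted dropout); alpha = 1/sqrt(K) is the usual softmax scaling by the head dimension. A minimal sketch of the thresholding scheme (illustrative names, not CK's reference_dropout API):

    #include <cstdint>

    // Keep the element when its 16-bit random draw falls below
    // floor(p_keep * 65535); rescale survivors by 1/p_keep.
    inline float apply_dropout(float p, std::uint16_t z, std::uint16_t keep_threshold, float rp_keep)
    {
        return z <= keep_threshold ? p * rp_keep : 0.f;
    }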
    std::cout << "do_verification: " << do_verification << std::endl;
    std::cout << "init_method: " << init_method << std::endl;
    std::cout << "time_kernel: " << time_kernel << std::endl;
    std::cout << "M: " << M << std::endl;
    std::cout << "N: " << N << std::endl;
    std::cout << "K: " << K << std::endl;
@@ -264,137 +150,36 @@ int main(int argc, char* argv[])
    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
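With these packed strides, LSE element (g0, g1, m) sits at linear offset g0*(G1*M) + g1*M + m; for example, with G1 = 2 and M = 512, element (1, 1, 5) maps to 1024 + 512 + 5 = 1541.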
-    Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-    Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
-    Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
-    Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
-
-    std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
-    std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
-    std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
-    std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
-    std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
-    std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
-
-    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
-
-    switch(init_method)
-    {
-    case 0: break;
-    case 1:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        break;
-    case 2:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
-        break;
-    case 3:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        break;
-    case 4:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        break;
-    case 5:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
-        // dO dot O = [0; 1; 2; ...]
-        break;
-    case 6:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
-        // assume mnko = 256
-        // P = softmax(QK) = 0.0039 * ones
-        // O = P V = 0.0039 * ones
-        // dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
-        // dO dot O = [127.5; ...]
-        // dS = P * (dP - dO dot O)
-        break;
-    default:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
-        // assume mnko = 256
-        // P = softmax(QK) = 0.0039 * ones
-        // O = P V = 0.0039 * ones
-        // dP = dO V = ones
-        // dS = P * (dP - (dO dot O))
-        //    = 0.0039 * ones * (ones - 0.0039*256)
-        //    = 0.0039 * ones * (ones - 1)
-        //    = 0
-    }
+    SimpleDeviceMem q_device_buf(sizeof(InputDataType) * G0 * G1 * M * K);
+    SimpleDeviceMem k_device_buf(sizeof(InputDataType) * G0 * G1 * N * K);
+    SimpleDeviceMem z_device_buf(sizeof(ZDataType) * G0 * G1 * M * N);
+    SimpleDeviceMem v_device_buf(sizeof(InputDataType) * G0 * G1 * O * N);
+    SimpleDeviceMem y_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);
+    SimpleDeviceMem lse_device_buf(sizeof(LSEDataType) * G0 * G1 * M);
+    SimpleDeviceMem qgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * M * K);
+    SimpleDeviceMem kgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * N * K);
+    SimpleDeviceMem vgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * O * N);
+    SimpleDeviceMem ygrad_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);
+
+    using DeviceOp = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<
+        2, 1, 1, 1, 1,
+        InputDataType, OutputDataType, ZDataType, LSEDataType,
+        ck::Tuple<>, ck::Tuple<>,
+        QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp,
+        MaskingSpec>;
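The hand-check comments in cases 5, 6, and the default case rest on the softmax-backward identity used by FlashAttention-style kernels. With P = softmax(S), dP = dO V^T and O = P V (row-wise in m):

$$dS_{mn} = P_{mn}\Bigl(dP_{mn} - \sum_{n'} P_{mn'}\,dP_{mn'}\Bigr) = P_{mn}\bigl(dP_{mn} - dO_m \cdot O_m\bigr),$$

since the weighted sum collapses: sum over n' of P_{mn'} (dO_m dot V_{n'}) equals dO_m dot O_m. In the default case every dP entry is 1 and dO_m dot O_m = 0.0039 * 256, which is approximately 1, so dS vanishes, exactly as the comment derives.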
-    Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
-    Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
-    Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
-    Tensor<LSEDataType> lse_g_m({BatchCount, M});
-
-    q_gs_ms_ks.ForEach(
-        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    k_gs_ns_ks.ForEach(
-        [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    v_gs_os_ns.ForEach(
-        [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
-    // qkv gradients have the same descriptor as with qkv
-    DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
-    DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
-    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
-    DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
-    DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
-
-    q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
-    k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
-    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
-    v_device_buf.ToDevice(v_gs_os_ns.mData.data());
-    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
-    using DeviceOp = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<
-        2, 1, 1, 1, 1,
-        InputDataType, OutputDataType, ZDataType, LSEDataType,
-        ck::Tuple<>, ck::Tuple<>,
-        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
-        MaskingSpec>;
    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();
@@ -409,54 +194,51 @@ int main(int argc, char* argv[])
    // profile device op instances
    std::cout << "Run all instances and do timing" << std::endl;
-    auto gemm    = DeviceGemmInstance{};
-    auto invoker = gemm.MakeInvoker();
    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            q_device_buf.GetDeviceBuffer(),
-            k_device_buf.GetDeviceBuffer(),
-            v_device_buf.GetDeviceBuffer(),
-            y_device_buf.GetDeviceBuffer(),
-            lse_device_buf.GetDeviceBuffer(),
-            ygrad_device_buf.GetDeviceBuffer(),
-            qgrad_device_buf.GetDeviceBuffer(),
-            kgrad_device_buf.GetDeviceBuffer(),
-            vgrad_device_buf.GetDeviceBuffer(),
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            q_device_buf.GetDeviceBuffer(),
+            nullptr, // set to nullptr
+            k_device_buf.GetDeviceBuffer(),
+            nullptr, // set to nullptr
+            v_device_buf.GetDeviceBuffer(),
+            y_device_buf.GetDeviceBuffer(),
+            lse_device_buf.GetDeviceBuffer(),
+            ygrad_device_buf.GetDeviceBuffer(),
+            qgrad_device_buf.GetDeviceBuffer(),
+            kgrad_device_buf.GetDeviceBuffer(),
+            vgrad_device_buf.GetDeviceBuffer(),
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
            k_gs_ns_ks_lengths,
            k_gs_ns_ks_strides,
            z_gs_ms_ns_lengths,
            z_gs_ms_ns_strides,
            v_gs_os_ns_lengths,
            v_gs_os_ns_strides,
            y_gs_ms_os_lengths,
            y_gs_ms_os_strides,
            lse_gs_ms_lengths,
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
            QKVElementOp{},
            QKVElementOp{},
            Scale{alpha},
            QKVElementOp{},
            YElementOp{},
            p_drop,
            std::tuple<unsigned long long, unsigned long long>(seed, offset));
        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
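A plausible reading of this FLOP model: the backward pass is counted as three M*N*K GEMMs (the S = Q K^T recompute plus the dQ and dK products) and two M*N*O GEMMs (the dV and dP products), at 2 FLOPs per multiply-accumulate, for each of the BatchCount = G0*G1 batched heads:

$$\text{flop} = (3\,MNK + 2\,MNO)\times 2\times G_0 G_1.$$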
@@ -496,40 +278,41 @@ int main(int argc, char* argv[])
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            q_device_buf.GetDeviceBuffer(),
-            k_device_buf.GetDeviceBuffer(),
-            v_device_buf.GetDeviceBuffer(),
-            y_device_buf.GetDeviceBuffer(),
-            lse_device_buf.GetDeviceBuffer(),
-            ygrad_device_buf.GetDeviceBuffer(),
-            qgrad_device_buf.GetDeviceBuffer(),
-            kgrad_device_buf.GetDeviceBuffer(),
-            vgrad_device_buf.GetDeviceBuffer(),
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            q_device_buf.GetDeviceBuffer(),
+            nullptr, // set to nullptr
+            k_device_buf.GetDeviceBuffer(),
+            nullptr, // set to nullptr
+            v_device_buf.GetDeviceBuffer(),
+            y_device_buf.GetDeviceBuffer(),
+            lse_device_buf.GetDeviceBuffer(),
+            ygrad_device_buf.GetDeviceBuffer(),
+            qgrad_device_buf.GetDeviceBuffer(),
+            kgrad_device_buf.GetDeviceBuffer(),
+            vgrad_device_buf.GetDeviceBuffer(),
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
            k_gs_ns_ks_lengths,
            k_gs_ns_ks_strides,
            z_gs_ms_ns_lengths,
            z_gs_ms_ns_strides,
            v_gs_os_ns_lengths,
            v_gs_os_ns_strides,
            y_gs_ms_os_lengths,
            y_gs_ms_os_strides,
            lse_gs_ms_lengths,
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
            QKVElementOp{},
            QKVElementOp{},
            Scale{alpha},
            QKVElementOp{},
            YElementOp{},
            p_drop,
            std::tuple<unsigned long long, unsigned long long>(seed, offset));
        auto invoker_ptr = op_ptr->MakeInvokerPointer();
...
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp (view file @ 883c060a)
@@ -18,15 +18,36 @@ namespace device {
namespace instance {

void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);

@@ -34,31 +55,29 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
-        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
+        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
@@ -68,7 +87,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
        2, 1, 1, 1, 1, BF16, BF16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);

@@ -76,29 +95,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, BF16, BF16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
-        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
+        MaskingSpecialization::MaskDisabled>>>& instances);

template <typename InputDataType,
          typename OutputDataType,
...
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
        {
            add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
        }
        else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
...
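The factory above dispatches at compile time: the template arguments of the requested DeviceBatchedMultiheadAttentionBackward select which add_device_..._instances overload populates op_ptrs, and client code then probes each returned instance at run time. A condensed sketch of the intended call sequence, assuming the DeviceOp alias and device buffers from the client example above (StreamConfig's second flag enables kernel timing):

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(/* pointers, lengths, strides, element ops as above */);
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            op_ptr->MakeInvokerPointer()->Run(argument_ptr.get(), StreamConfig{nullptr, false});
            break; // take the first supported instance; the example instead times them all
        }
    }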
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp (view file @ 883c060a)
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
//     ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
    // clang-format off
    // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
    // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
    // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
    // clang-format on
    >;
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, BF16, BF16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}

void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, BF16, BF16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp (view file @ 883c060a)
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
//     ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
    // clang-format off
    // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
    // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
    // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
    // clang-format on
    >;
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}

void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
...