Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
20fc6679
Commit
20fc6679
authored
Feb 23, 2023
by
danyao12
Browse files
attn bwd prototype1 bf16
parent
010ed35f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
4669 additions
and
9 deletions
+4669
-9
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+1
-0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+3
-3
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
+884
-0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
+77
-0
include/ck/ck.hpp
include/ck/ck.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
..._multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
+1258
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
..._multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
+2369
-0
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+69
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
20fc6679
...
...
@@ -11,6 +11,7 @@ add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_
add_example_executable
(
example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward_pt1_bf16 batched_multihead_attention_backward_pt1_bf16.cpp
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
add_dependencies
(
example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
20fc6679
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK
1
#define USING_MASK
0
#include <iostream>
#include <numeric>
...
...
@@ -268,8 +268,8 @@ int run(int argc, char* argv[])
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
128
;
ck
::
index_t
O
=
128
;
ck
::
index_t
G0
=
3
;
ck
::
index_t
G1
=
2
;
ck
::
index_t
G0
=
54
;
ck
::
index_t
G1
=
16
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
0 → 100644
View file @
20fc6679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Computation graph:
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_HD32 0
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
BF16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskOutUpperTriangle
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
// If Headdim/K/O <= 32, ues 1st template.
// If 32 < Headdim/K/O <= 64, ues 2nd template.
#if USING_HD32
// 1st template
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#else
//2nd template
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
DataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
DataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
DataType
,
DataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using
ReferenceGemmGradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
DataType
,
DataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
DataType
,
DataType
>
;
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorV
,
typename
TensorS
,
typename
TensorP
,
typename
TensorZ
,
typename
TensorY
,
typename
TensorLSE
=
TensorP
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
const
TensorK
&
k_g_n_k
,
const
TensorV
&
v_g_n_o
,
const
float
alpha
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
,
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
ushort
p_dropout_in_16bits
,
float
rp_dropout
)
{
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
q_g_m_k
,
k_g_k_n
,
s_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
#if USING_MASK
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
#endif
// P = Softmax(S)
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
},
&
lse_g_m
);
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// P_dropped
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// Y = P_dropout * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_drop_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
2
;
// method 1 will have slightly higher error; TODO: to investigate
bool
time_kernel
=
true
;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
// 512
ck
::
index_t
N
=
512
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
G0
=
54
;
// 54
ck
::
index_t
G1
=
16
;
// 16
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.2
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
Tensor
<
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"z_gs_ms_ks: "
<<
z_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
0
});
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
break
;
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
break
;
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
break
;
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break
;
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
// dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
// calculate y & log-sum-exp beforehand
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
y_gs_ms_os
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// get z matrix
{
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
}
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
vgrad_device_buf
.
SetZero
();
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
// 5 GEMM ops in total:
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std
::
size_t
num_btype
=
(
sizeof
(
DataType
)
*
M
*
K
+
sizeof
(
DataType
)
*
K
*
N
+
sizeof
(
DataType
)
*
N
*
O
+
sizeof
(
DataType
)
*
M
*
O
)
*
size_t
(
2
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
// copy z matirx data form device
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
if
(
do_verification
)
{
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
// call kernel again
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
vgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
Tensor
<
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
{
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
}
#endif
// Gradients
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
// dP_dropout = dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_drop_g_m_n ref:
\n
"
<<
pgrad_drop_g_m_n
;
}
#endif
// dP = dP_dropout x Z
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
#if PRINT_HOST
{
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"y_g_m_o ref:
\n
"
<<
y_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
}
#endif
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_drop_g_n_m ref:
\n
"
<<
p_drop_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
}
#endif
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
}
#endif
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
}
#endif
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
vgrad_device_buf
.
FromDevice
(
vgrad_gs_os_ns_device_result
.
mData
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
std
::
cout
<<
"Checking qgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking kgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking vgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
}
return
pass
?
((
void
)(
std
::
cout
<<
"pass
\n
"
),
0
)
:
((
void
)(
std
::
cout
<<
"fail
\n
"
),
1
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
View file @
20fc6679
...
...
@@ -25,6 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_HD32 0
#include <iostream>
#include <numeric>
...
...
@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
// If Headdim/K/O <= 32, ues 1st template.
// If 32 < Headdim/K/O <= 64, ues 2nd template.
#if USING_HD32
// 1st template
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#else
//2nd template
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
...
...
@@ -125,6 +200,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -151,6 +227,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
...
...
include/ck/ck.hpp
View file @
20fc6679
...
...
@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
0
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-regsiter sub-dword transpose
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
20fc6679
...
...
@@ -204,6 +204,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
Gemm2NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
0 → 100644
View file @
20fc6679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
GridwiseGemm
,
typename
DataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
VGradGridDescriptor_N_O
,
typename
YGradGridDesc_O0_M_O1
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
1
)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1
(
const
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_b1_grid
,
const
DataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
VGradGridDescriptor_N_O
vgrad_grid_desc_n_o
,
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
vgrad_grid_desc_n_o
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
p_drop
;
ignore
=
seed
;
ignore
=
offset
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
B1Spec
,
TensorSpecialization
CSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
// Gemm0NPerBlock
index_t
KPerBlock
,
// Gemm0KPerBlock
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
B1K1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1BlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
static
constexpr
auto
padder
=
GemmGemmPadder
<
GemmSpec
,
Number
<
MPerBlock
>
,
Number
<
NPerBlock
>
,
Number
<
KPerBlock
>
,
Number
<
Gemm1NPerBlock
>>
{};
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpec
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths_vec
,
b_gs_ns_ks_strides_vec
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides_vec
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
b1_gs_gemm1ns_gemm1ks_strides_vec
),
Number
<
B1K1
>
{});
}
//
// dV = P^T * dY
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides_vec
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const
index_t
num_dims
=
NumDimG
+
NumDimN
+
NumDimO
;
// 0, 1, .. NumDimG - 1
std
::
vector
<
index_t
>
gs_ids
(
NumDimG
);
std
::
iota
(
gs_ids
.
begin
(),
gs_ids
.
end
(),
0
);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std
::
vector
<
index_t
>
os_ids
(
NumDimO
);
std
::
iota
(
os_ids
.
begin
(),
os_ids
.
end
(),
NumDimG
);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std
::
vector
<
index_t
>
ns_ids
(
NumDimN
);
std
::
iota
(
ns_ids
.
begin
(),
ns_ids
.
end
(),
NumDimG
+
NumDimO
);
std
::
vector
<
index_t
>
ids_old2new
;
ids_old2new
.
insert
(
ids_old2new
.
end
(),
gs_ids
.
begin
(),
gs_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths_vec
(
num_dims
),
v_gs_ns_os_strides_vec
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths_vec
[
i
]
=
v_gs_os_ns_lengths_vec
[
id_new
];
v_gs_ns_os_strides_vec
[
i
]
=
v_gs_os_ns_strides_vec
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
}
template
<
typename
YGridDesc_M_O
>
static
auto
MakeYGradGridDescriptor_M0_O_M1
(
const
YGridDesc_M_O
&
ygrad_grid_desc_m_o
)
{
const
auto
M
=
ygrad_grid_desc_m_o
.
GetLength
(
I0
);
const
auto
O
=
ygrad_grid_desc_m_o
.
GetLength
(
I1
);
const
auto
Y_M0
=
M
/
Y_M1
;
return
transform_tensor_descriptor
(
ygrad_grid_desc_m_o
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Y_M0
,
Y_M1
)),
make_pass_through_transform
(
O
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides_vec
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths_vec
,
y_gs_ms_os_strides_vec
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides_vec
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const
index_t
num_dims
=
NumDimG
+
NumDimN
+
NumDimO
;
// 0, 1, .. NumDimG - 1
std
::
vector
<
index_t
>
gs_ids
(
NumDimG
);
std
::
iota
(
gs_ids
.
begin
(),
gs_ids
.
end
(),
0
);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std
::
vector
<
index_t
>
os_ids
(
NumDimO
);
std
::
iota
(
os_ids
.
begin
(),
os_ids
.
end
(),
NumDimG
);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std
::
vector
<
index_t
>
ns_ids
(
NumDimN
);
std
::
iota
(
ns_ids
.
begin
(),
ns_ids
.
end
(),
NumDimG
+
NumDimO
);
std
::
vector
<
index_t
>
ids_old2new
;
ids_old2new
.
insert
(
ids_old2new
.
end
(),
gs_ids
.
begin
(),
gs_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths_vec
(
num_dims
),
v_gs_ns_os_strides_vec
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths_vec
[
i
]
=
v_gs_os_ns_lengths_vec
[
id_new
];
v_gs_ns_os_strides_vec
[
i
]
=
v_gs_os_ns_strides_vec
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
// N_O to O0_N_O1; to refactor
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides_vec
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths_vec
,
z_gs_ms_ns_strides_vec
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides_vec
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths_vec
,
q_gs_ms_ks_strides_vec
);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides_vec
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths_vec
,
k_gs_ns_ks_strides_vec
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
{
const
auto
lse_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M
return
transform_tensor_descriptor
(
lse_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
else
{
// not pad M
return
lse_grid_desc_mraw
;
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
VGradGridDesc_N_O
=
decltype
(
MakeVGradGridDescriptor_N_O
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
{
return
MaskOutUpperTrianglePredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
BatchStrideLSE_
(
BatchStrideLSE
)
{
}
__host__
__device__
constexpr
long_index_t
GetABasePtr
(
index_t
g_idx
)
const
{
return
a_grid_desc_g_m_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetBBasePtr
(
index_t
g_idx
)
const
{
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1BasePtr
(
index_t
g_idx
)
const
{
return
b1_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
{
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetLSEBasePtr
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
BatchStrideLSE_
;
};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
DataType
,
// TODO: distinguish A/B datatype
LSEDataType
,
GemmAccDataType
,
CShuffleDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
LSEGridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
Gemm2NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
true
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DataType
*
p_a_grid
,
const
DataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
DataType
*
p_b1_grid
,
const
DataType
*
p_c_grid
,
// for dS
const
LSEDataType
*
p_lse_grid
,
const
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_lse_grid_
{
p_lse_grid
},
p_ygrad_grid_
{
p_ygrad_grid
},
p_qgrad_grid_
{
p_qgrad_grid
},
p_kgrad_grid_
{
p_kgrad_grid
},
p_vgrad_grid_
{
p_vgrad_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
vgrad_grid_desc_n_o_
{
DeviceOp
::
MakeVGradGridDescriptor_N_O
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
ygrad_grid_desc_o0_m_o1_
{
DeviceOp
::
MakeYGradGridDescriptor_O0_M_O1
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
// batch offsets
a_grid_desc_g_m_k_
{
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
b1_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
b1_gs_gemm1ns_gemm1ks_lengths
[
NumDimG
+
NumDimO
-
1
]},
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]},
b_nz_kz_strides_
{
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
+
NumDimK
-
1
]},
b1_nz_kz_strides_
{
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
-
1
],
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
ignore
=
acc0_biases_gs_ms_ns_strides
;
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
y_grid_desc_m_o_
,
block_2_ctile_map_
))
{
y_grid_desc_mblock_mperblock_oblock_operblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o_
);
}
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
// Print();
}
void
Print
()
const
{
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
// a_grid_desc_g_m_k_.Print();
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_o_n_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
std
::
cout
<<
"c_grid_desc_g_m_o_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
// c_grid_desc_g_m_n_.Print();
std
::
cout
<<
"vgrad_grid_desc_n_o_: "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I0
)
<<
", "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I1
)
<<
'\n'
;
std
::
cout
<<
"ygrad_grid_desc_o0_m_o1_: "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I2
)
<<
'\n'
;
}
// pointers
const
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
VGradGridDesc_N_O
vgrad_grid_desc_n_o_
;
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1_
;
// batch offsets
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_mz_nz_kz_gemm1nz_
;
std
::
vector
<
index_t
>
a_mz_kz_strides_
;
std
::
vector
<
index_t
>
b_nz_kz_strides_
;
std
::
vector
<
index_t
>
b1_nz_kz_strides_
;
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
DeviceOp
::
IsSupportedArgument
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! unsupported argument"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
y_grid_desc_m_o_
)
*
arg
.
batch_count_
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1
<
GridwiseGemm
,
DataType
,
ZDataType
,
LSEDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
VGradGridDesc_N_O
,
DeviceOp
::
YGradGridDesc_O0_M_O1
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_ygrad_grid_
,
arg
.
p_qgrad_grid_
,
arg
.
p_kgrad_grid_
,
arg
.
p_vgrad_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
vgrad_grid_desc_n_o_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
#if 1
// if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
// {
// ave_time = launch_kernel(integral_constant<bool, true>{});
// }
// else
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
#endif
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if 0
arg.Print();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
// TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
{
return
false
;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const
auto
MzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
];
const
auto
NzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
];
const
auto
KzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
2
];
const
auto
Gemm1NzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
b_extent_lowest
=
BBlockTransferSrcVectorDim
==
2
?
KzRaw
:
NzRaw
;
const
auto
b1_extent_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
NzRaw
:
Gemm1NzRaw
;
const
auto
c_extent_lowest
=
Gemm1NzRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
// Check vector load/store requirement
const
auto
a_stride_lowest
=
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
const
auto
b_stride_lowest
=
BBlockTransferSrcVectorDim
==
2
?
arg
.
b_nz_kz_strides_
[
1
]
:
arg
.
b_nz_kz_strides_
[
0
];
const
auto
b1_stride_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
arg
.
b1_nz_kz_strides_
[
1
]
:
arg
.
b1_nz_kz_strides_
[
0
];
const
auto
c_stride_lowest
=
arg
.
c_mz_gemm1nz_strides_
[
1
];
// cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if
(
!
(
a_stride_lowest
==
1
||
b_stride_lowest
==
1
||
b1_stride_lowest
==
1
||
c_stride_lowest
==
1
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_m_o_
,
arg
.
block_2_ctile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
DataType
*
p_a
,
const
DataType
*
p_b
,
ZDataType
*
p_z
,
const
DataType
*
p_b1
,
const
DataType
*
p_c
,
const
LSEDataType
*
p_lse
,
const
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
return
Argument
{
p_a
,
p_b
,
p_z
,
p_b1
,
p_c
,
p_lse
,
p_ygrad_grid
,
p_qgrad_grid
,
p_kgrad_grid
,
p_vgrad_grid
,
p_acc0_biases
,
p_acc1_biases
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
p_drop
,
seeds
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// FIXME: constness
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_z
,
const
void
*
p_b1
,
const
void
*
p_c
,
const
void
*
p_lse
,
const
void
*
p_ygrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
DataType
*>
(
p_b1
),
static_cast
<
const
DataType
*>
(
p_c
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
static_cast
<
DataType
*>
(
p_qgrad_grid
),
static_cast
<
DataType
*>
(
p_kgrad_grid
),
static_cast
<
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
acc1_biases_gs_ms_gemm1ns_strides
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
p_drop
,
seeds
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
// override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"B0Spec"
<<
getTensorSpecializationString
(
BSpec
)
<<
", "
<<
"B1Spec"
<<
getTensorSpecializationString
(
B1Spec
)
<<
", "
<<
"CSpec"
<<
getTensorSpecializationString
(
CSpec
)
<<
", "
<<
getMaskingSpecializationString
(
MaskingSpec
)
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
20fc6679
...
...
@@ -51,6 +51,7 @@ template <typename DataType,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert
(
Sum_M
%
MPerXdl
==
0
,
""
);
static
constexpr
index_t
GemmNWave
=
2
;
static
constexpr
index_t
GemmNWave
=
Free0_N
/
Gemm2NXdlPerWave
/
MPerXdl
;
static
constexpr
index_t
GemmOWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Free0_N
/
GemmNWave
/
MPerXdl
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmORepeat
=
Free1_O
/
GemmOWave
/
NPerXdl
;
static
constexpr
index_t
GemmMLoop
=
Free1_M
/
Sum_M
;
static
constexpr
index_t
GemmMPack
=
...
...
@@ -1563,8 +1564,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
@@ -1937,7 +1938,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
else
{
s_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
}
});
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp
0 → 100644
View file @
20fc6679
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace
ck
{
template
<
typename
DataType
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatLSE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
SElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
QGridDesc_K0_M_K1
,
typename
KGridDesc_K0_N_K1
,
typename
ZGridDesc_M_N
,
typename
VGridDesc_O0_N_O1
,
typename
CGridDesc_M_N
,
typename
LSEGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
B1K1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1ThreadTransferSrcResetCoordinateAfterRun
,
index_t
B1BlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// C desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
MPerXdl
,
MPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
template
<
typename
AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
>
__host__
__device__
static
constexpr
auto
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1
(
const
AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
&
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const
auto
m0
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
const
auto
n0
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
const
auto
m1
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
const
auto
n1
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
const
auto
m2
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
const
auto
n2
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
const
auto
n3
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
const
auto
n4
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
return
transform_tensor_descriptor
(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
n0
,
n1
,
n2
,
n3
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
m0
,
m1
,
m2
)),
make_pass_through_transform
(
n4
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
6
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
template
<
typename
Gemm2Param
>
__host__
__device__
static
constexpr
auto
GetA2BlockDescriptor_M0_N_M1
()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
Gemm2Param
::
A_M0
>
{},
Number
<
Gemm2Param
::
Free0_N
>
{},
Number
<
Gemm2Param
::
A_M1
>
{}),
make_tuple
(
Number
<
Gemm2Param
::
Free0_N
+
Gemm2Param
::
A_LdsPad
>
{}
*
Number
<
Gemm2Param
::
A_M1
>
{},
Number
<
Gemm2Param
::
A_M1
>
{},
I1
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc_o0_n_o1
.
GetLength
(
I2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
// check gemm1 gridwise gemm pipeline
if
(
!
(
NPerBlock
%
Gemm1KPerBlock
==
0
))
{
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
// __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
// {
// const index_t num_loop = K / KPerBlock;
// return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
// }
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
Gemm1NPerBlock
;
const
auto
y_grid_desc_mblock_mperblock_oblock_operblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
y_grid_desc_mblock_mperblock_oblock_operblock
;
}
__host__
__device__
static
constexpr
auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
const
LSEGridDesc_M
&
lse_grid_desc_m
)
{
const
index_t
M
=
lse_grid_desc_m
.
GetLength
(
I0
);
const
index_t
MBlock
=
M
/
MPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
const
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
transform_tensor_descriptor
(
lse_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MXdlPerWave
>
{},
MWave
,
Number
<
MPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}));
return
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
Gemm1NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
// Q / K / V / dY
struct
GemmBlockwiseCopy
{
// Q matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
q_block_desc_k0_m_k1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// K matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
k_block_desc_k0_n_k1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// V matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
v_block_desc_k0_n_k1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// dY matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
ygrad_block_desc_k0_m_k1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// // A matrix in LDS memory, dst of blockwise copy
// static constexpr auto a_block_desc_ak0_m_ak1 =
// GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// // B matrix in LDS memory, dst of blockwise copy
// static constexpr auto b_block_desc_bk0_n_bk1 =
// GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template
<
typename
GridDesc_K0_M_K1
>
using
QBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_M_K1
,
decltype
(
q_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
template
<
typename
GridDesc_K0_N_K1
>
using
KBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_N_K1
,
decltype
(
k_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
template
<
typename
GridDesc_K0_N_K1
>
using
VBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_N_K1
,
decltype
(
v_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
template
<
typename
GridDesc_K0_M_K1
>
using
YGradBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_M_K1
,
decltype
(
ygrad_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
auto
gemm_tile_k_block_slice_copy_step
=
make_multi_index
(
0
,
NPerBlock
,
0
);
static
constexpr
auto
gemm_tile_v_block_slice_copy_step
=
make_multi_index
(
0
,
NPerBlock
,
0
);
};
// S / dP Gemm (type 1 rcr)
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b_block_desc_bk0_n_bk1
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
true
>
;
// TransposeC
};
// Y / dQ Gemm (type 2 rrr)
template
<
typename
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
,
typename
ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4
>
struct
Gemm1
{
private:
static
constexpr
auto
m0
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I0
);
static
constexpr
auto
n0
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I1
);
static
constexpr
auto
m1
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I2
);
static
constexpr
auto
n1
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I3
);
static
constexpr
auto
m2
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I4
);
static
constexpr
auto
n2
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I5
);
static
constexpr
auto
n3
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I6
);
static
constexpr
auto
n4
=
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I7
);
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
static
constexpr
auto
N3
=
ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4
{}.
GetLength
(
I6
);
public:
static
constexpr
auto
AThreadSliceLength_K0
=
Number
<
Gemm1KPerBlock
/
n4
/
N3
>
{};
static
constexpr
auto
AThreadSliceLength_M
=
Number
<
m0
*
m1
*
m2
>
{};
static
constexpr
auto
AThreadSliceLength_K1
=
Number
<
n4
>
{};
// A source matrix layout in AccVGPR
static
constexpr
auto
a_src_thread_desc_k0_m_k1
=
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1
(
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static
constexpr
auto
a_thread_desc_k0_m_k1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
AThreadSliceLength_K0
,
AThreadSliceLength_M
,
AThreadSliceLength_K1
));
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bn0_k_bn1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
1
,
1
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm1NXdlPerWave
,
1
,
1
>
(
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
auto
ASrcScalarPerVector
=
n4
;
using
AThreadSliceLengths_K0_M_K1
=
decltype
(
a_thread_desc_k0_m_k1
.
GetLengths
());
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatGemmAcc
,
DataType
,
decltype
(
a_src_thread_desc_k0_m_k1
),
decltype
(
a_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
AThreadSliceLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
2
,
ASrcScalarPerVector
>
;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static
constexpr
auto
a_block_slice_copy_step
=
make_tuple
(
AThreadSliceLength_K0
,
I0
,
I0
);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0MWaves
;
static
constexpr
index_t
GemmNWave
=
Gemm0NWaves
;
static
constexpr
index_t
GemmMRepeat
=
MXdlPerWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm1NXdlPerWave
;
static
constexpr
index_t
GemmKLoop
=
NPerBlock
/
Gemm1KPerBlock
;
// 128/32=4
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 4
static
constexpr
index_t
B_K2
=
N3
;
// 2
static
constexpr
index_t
B_K1
=
Gemm1KPerBlock
/
B_K2
/
B_K3
;
// 4
static
constexpr
index_t
B_K0
=
GemmKLoop
;
// 4
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
const
auto
N0_
=
b_block_desc_bn0_k_bn1
.
GetLength
(
I0
);
const
auto
K_
=
b_block_desc_bn0_k_bn1
.
GetLength
(
I1
);
const
auto
N1_
=
b_block_desc_bn0_k_bn1
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(64, 128)
b_block_desc_bn0_k_bn1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(8, 8)
make_pass_through_transform
(
K_
)),
// 128
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_block_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmNRepeat
,
GemmNWave
,
NPerXdl
)),
//(2, 1, 32)
make_unmerge_transform
(
make_tuple
(
B_K0
,
B_K1
,
B_K2
,
B_K3
))),
//(4, 4, 2, 4)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}));
}
static
constexpr
auto
b_block_desc_n0_n1_n2_k0_k1_k2_k3
=
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
();
using
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
=
Sequence
<
GemmNRepeat
,
1
,
1
,
1
,
B_K1
,
1
,
B_K3
>
;
static
constexpr
auto
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
GemmNRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{},
I1
,
Number
<
B_K3
>
{}));
__host__
__device__
static
constexpr
auto
MakeBThreadDesc_K0_N_K1
()
{
constexpr
auto
b_thread_desc_n_k
=
transform_tensor_descriptor
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
GemmNRepeat
>
{},
I1
,
I1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
I1
,
Number
<
B_K1
>
{},
I1
,
Number
<
B_K3
>
{}))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_thread_desc_n_k
,
make_tuple
(
make_pass_through_transform
(
Number
<
GemmNRepeat
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
B_K1
>
{},
Number
<
B_K3
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
}
static
constexpr
auto
b_thread_desc_k0_n_k1
=
MakeBThreadDesc_K0_N_K1
();
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
decltype
(
b_block_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
1
,
1
,
true
>
;
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
);
static
constexpr
auto
b_block_reset_copy_step
=
make_multi_index
(
0
,
0
,
0
,
-
B_K0
,
0
,
0
,
0
);
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
FloatGemmAcc
,
decltype
(
a_thread_desc_k0_m_k1
),
decltype
(
b_thread_desc_k0_n_k1
),
decltype
(
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
a_thread_desc_k0_m_k1
)),
decltype
(
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
b_thread_desc_k0_n_k1
)),
MPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
Gemm1NXdlPerWave
,
GemmKPack
,
true
,
// TransposeC
GemmKPack
,
// AMmaKStride
GemmKPack
>
;
};
// dV / dK Gemm (type 3 crr)
// Describes tuning parameter for C2_n_o = A2_n_m * B2_m_o
template
<
index_t
Sum_M_
=
MPerXdl
*
2
>
struct
Gemm2Params_N_O_M_
{
static
constexpr
index_t
Free0_N
=
NPerBlock
;
static
constexpr
index_t
Free1_M
=
MPerBlock
;
static
constexpr
index_t
Free1_O
=
Gemm1NPerBlock
;
static
constexpr
index_t
Sum_M
=
Sum_M_
;
static
constexpr
index_t
A_M1
=
8
;
// P will be row-major
static
constexpr
index_t
A_M0
=
Sum_M
/
A_M1
;
static
constexpr
index_t
A_LdsPad
=
0
;
// how many multiples of M1 per N * M1 elements
static_assert
(
Sum_M
%
MPerXdl
==
0
,
""
);
static
constexpr
index_t
GemmNWave
=
Free0_N
/
Gemm2NXdlPerWave
/
MPerXdl
;
static
constexpr
index_t
GemmOWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmORepeat
=
Free1_O
/
GemmOWave
/
NPerXdl
;
static
constexpr
index_t
GemmMLoop
=
Free1_M
/
Sum_M
;
static
constexpr
index_t
GemmMPack
=
math
::
max
(
A_M1
,
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_M3
=
GemmMPack
;
// 8
static
constexpr
index_t
B_M2
=
XdlopsGemm
<
DataType
,
MPerXdl
,
NPerXdl
,
GemmMPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_M1
=
Sum_M
/
B_M2
/
B_M3
;
// 4
static
constexpr
index_t
B_M0
=
GemmMLoop
;
// 2
__host__
__device__
static
constexpr
auto
GetABlockSliceLengths_M0_N0_M1_N1_M2_N2
()
{
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr
index_t
m
=
Gemm2Params_N_O_M
::
Sum_M
-
1
;
constexpr
index_t
m2
=
m
%
MPerXdl
;
constexpr
index_t
m1
=
m
/
MPerXdl
%
Gemm0MWaves
;
constexpr
index_t
m0
=
m
/
MPerXdl
/
Gemm0MWaves
%
MXdlPerWave
;
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr
index_t
n
=
Gemm2Params_N_O_M
::
Free0_N
-
1
;
constexpr
index_t
n2
=
n
%
NPerXdl
;
constexpr
index_t
n1
=
n
/
NPerXdl
%
Gemm0NWaves
;
constexpr
index_t
n0
=
n
/
NPerXdl
/
Gemm0NWaves
%
NXdlPerWave
;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
>
{}
+
Sequence
<
1
,
1
,
1
,
1
,
1
,
1
>
{};
}
__host__
__device__
static
constexpr
auto
GetABlockSliceLengths_M0_N0_M1_N1
()
{
return
generate_sequence_v2
(
[](
auto
I
)
{
return
GetABlockSliceLengths_M0_N0_M1_N1_M2_N2
().
At
(
I
);
},
Number
<
4
>
{});
}
using
ABlockSliceLengths_M0_N0_M1_N1
=
decltype
(
GetABlockSliceLengths_M0_N0_M1_N1
());
};
using
Gemm2Params_N_O_M
=
Gemm2Params_N_O_M_
<>
;
// tune later
// dV / dK Gemm (type 3 crr)
template
<
typename
Gemm2Params_N_O_M
,
typename
ASrcBlockwiseGemm
>
struct
Gemm2
{
private:
static
constexpr
auto
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
ASrcBlockwiseGemm
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
static
constexpr
auto
M0
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
// repeat
static
constexpr
auto
N0
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
static
constexpr
auto
M1
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
// wave
static
constexpr
auto
N1
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
static
constexpr
auto
M2
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
// xdl
static
constexpr
auto
N2
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
static
constexpr
auto
N3
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
static
constexpr
auto
N4
=
a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static
constexpr
auto
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
ASrcBlockwiseGemm
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_m0_n_m1
=
GetA2BlockDescriptor_M0_N_M1
<
Gemm2Params_N_O_M
>
();
// // B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_o0_m_o1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
template
<
typename
ABlockDesc_M0_N_M1
>
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_N0_N1_N2_M
(
const
ABlockDesc_M0_N_M1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm2Params_N_O_M
::
GemmNRepeat
,
Gemm2Params_N_O_M
::
GemmNWave
,
MPerXdl
>
(
ABlockDesc_M0_N_M1
{});
}
template
<
typename
BBlockDesc_M0_O_M1
>
__host__
__device__
static
constexpr
auto
MakeGemm2BMmaTileDescriptor_O0_O1_O2_M
(
const
BBlockDesc_M0_O_M1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm2Params_N_O_M
::
GemmORepeat
,
1
,
1
>
(
BBlockDesc_M0_O_M1
{});
}
__host__
__device__
static
constexpr
auto
MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4
()
{
const
auto
M0_
=
a_block_desc_m0_n_m1
.
GetLength
(
I0
);
const
auto
N_
=
a_block_desc_m0_n_m1
.
GetLength
(
I1
);
const
auto
M1_
=
a_block_desc_m0_n_m1
.
GetLength
(
I2
);
const
auto
a_block_desc_m_n
=
transform_tensor_descriptor
(
a_block_desc_m0_n_m1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
M0_
,
M1_
)),
make_pass_through_transform
(
N_
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return
transform_tensor_descriptor
(
a_block_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
I1
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__
__device__
static
auto
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
()
{
const
auto
a_thread_origin_on_block_idx
=
ASrcBlockwiseGemm
::
CalculateCThreadOriginDataIndex8D
(
I0
,
I0
,
I0
,
I0
);
constexpr
auto
c_block_slice_lengths_m0_n0_m1_n1
=
typename
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
{};
// mrepeat, nrepeat,
// mwaves, nwaves,
return
make_tuple
(
a_thread_origin_on_block_idx
[
I0
],
// mrepeat
a_thread_origin_on_block_idx
[
I1
],
// nrepeat
a_thread_origin_on_block_idx
[
I2
]
%
c_block_slice_lengths_m0_n0_m1_n1
[
I2
],
// mwave
a_thread_origin_on_block_idx
[
I3
]
%
c_block_slice_lengths_m0_n0_m1_n1
[
I3
],
// nwave
a_thread_origin_on_block_idx
[
I4
],
// xdlops
a_thread_origin_on_block_idx
[
I5
],
a_thread_origin_on_block_idx
[
I6
],
a_thread_origin_on_block_idx
[
I7
]);
}
static
constexpr
auto
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4
();
using
ASrcBlockSliceWindowIterator
=
SpaceFillingCurve
<
Sequence
<
M0
,
N0
,
M1
,
N1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
typename
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
,
false
>
;
template
<
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
decltype
(
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
ElementwiseOp
,
Sequence
<
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I0
),
// ThreadSliceLengths
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I1
),
I1
,
I1
,
I1
,
N2
,
I1
,
N4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
;
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_O0_O1_O2_M0_M1_M2_M3
()
{
const
auto
O0_
=
b_block_desc_o0_m_o1
.
GetLength
(
I0
);
const
auto
M_
=
b_block_desc_o0_m_o1
.
GetLength
(
I1
);
const
auto
O1_
=
b_block_desc_o0_m_o1
.
GetLength
(
I2
);
constexpr
auto
b_block_desc_o_m
=
transform_tensor_descriptor
(
//(64, 128)
b_block_desc_o0_m_o1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
O0_
,
O1_
)),
//(8, 8)
make_pass_through_transform
(
M_
)),
// 128
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_block_desc_o_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm2Params_N_O_M
::
GemmORepeat
,
Gemm2Params_N_O_M
::
GemmOWave
,
NPerXdl
)),
//(1, 2, 32)
make_unmerge_transform
(
make_tuple
(
Gemm2Params_N_O_M
::
B_M0
,
Gemm2Params_N_O_M
::
B_M1
,
Gemm2Params_N_O_M
::
B_M2
,
Gemm2Params_N_O_M
::
B_M3
))),
//(2, 4, 2, 8)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}));
}
static
constexpr
auto
b_block_desc_o0_o1_o2_m0_m1_m2_m3
=
MakeBBlockDesc_O0_O1_O2_M0_M1_M2_M3
();
using
BThreadSlice_O0_O1_O2_M0_M1_M2_M3
=
Sequence
<
Gemm2Params_N_O_M
::
GemmORepeat
,
1
,
1
,
1
,
Gemm2Params_N_O_M
::
B_M1
,
1
,
Gemm2Params_N_O_M
::
B_M3
>
;
static
constexpr
auto
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
Gemm2Params_N_O_M
::
GemmORepeat
>
{},
I1
,
I1
,
I1
,
Number
<
Gemm2Params_N_O_M
::
B_M1
>
{},
I1
,
Number
<
Gemm2Params_N_O_M
::
B_M3
>
{}));
__host__
__device__
static
constexpr
auto
MakeBThreadDesc_M0_O_M1
()
{
constexpr
auto
b_thread_desc_o_m
=
transform_tensor_descriptor
(
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
Gemm2Params_N_O_M
::
GemmORepeat
>
{},
I1
,
I1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
I1
,
Number
<
Gemm2Params_N_O_M
::
B_M1
>
{},
I1
,
Number
<
Gemm2Params_N_O_M
::
B_M3
>
{}))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_thread_desc_o_m
,
make_tuple
(
make_pass_through_transform
(
Number
<
Gemm2Params_N_O_M
::
GemmORepeat
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
Gemm2Params_N_O_M
::
B_M1
>
{},
Number
<
Gemm2Params_N_O_M
::
B_M3
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
}
static
constexpr
auto
b_thread_desc_m0_o_m1
=
MakeBThreadDesc_M0_O_M1
();
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
decltype
(
b_block_desc_o0_o1_o2_m0_m1_m2_m3
),
decltype
(
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
),
BThreadSlice_O0_O1_O2_M0_M1_M2_M3
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
1
,
1
,
true
>
;
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
);
static
constexpr
auto
b_block_reset_copy_step
=
make_multi_index
(
0
,
0
,
0
,
-
Gemm2Params_N_O_M
::
B_M0
,
0
,
0
,
0
);
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_m0_n_m1
),
decltype
(
b_thread_desc_m0_o_m1
),
decltype
(
MakeGemm2AMmaTileDescriptor_N0_N1_N2_M
(
a_block_desc_m0_n_m1
)),
decltype
(
MakeGemm2BMmaTileDescriptor_O0_O1_O2_M
(
b_thread_desc_m0_o_m1
)),
NPerBlock
,
Gemm1NPerBlock
,
Gemm2Params_N_O_M
::
Sum_M
,
MPerXdl
,
NPerXdl
,
Gemm2Params_N_O_M
::
GemmNRepeat
,
Gemm2Params_N_O_M
::
GemmORepeat
,
Gemm2Params_N_O_M
::
GemmMPack
,
true
,
// TransposeC
Gemm2Params_N_O_M
::
GemmMPack
*
XdlopsGemm
<
DataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params_N_O_M
::
GemmMPack
,
false
>
{}
.
K0PerXdlops
,
Gemm2Params_N_O_M
::
GemmMPack
>
;
static
constexpr
auto
c_block_slice_copy_step
=
make_multi_index
(
Gemm2Params_N_O_M
::
GemmNRepeat
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
template
<
typename
CGradDesc_N_O
>
__host__
__device__
static
auto
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
const
CGradDesc_N_O
&
c_grid_desc_n_o
)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const
auto
c_grid_desc_n0_o0_n1_o1_n2_o2
=
transform_tensor_descriptor
(
c_grid_desc_n_o
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
Gemm2Params_N_O_M
::
GemmNWave
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
I1
,
Gemm2Params_N_O_M
::
GemmOWave
,
NPerXdl
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
const
auto
c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
BlockwiseGemm
{}.
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
(
c_grid_desc_n0_o0_n1_o1_n2_o2
);
return
c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
;
}
static
constexpr
auto
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
BlockwiseGemm
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
__host__
__device__
static
auto
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
{
return
to_multi_index
(
BlockwiseGemm
::
CalculateCThreadOriginDataIndex8D
(
I0
,
I0
,
I0
,
I0
));
}
template
<
typename
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
ElementwiseOp
,
// CElementwiseOperation
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
.
GetLengths
()),
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// AccessOrder
7
,
// VectorDim
2
,
// ScalarPerVector
InMemoryDataOperationEnum
::
AtomicAdd
,
// GlobalMemoryDataOperation
1
,
// DstScalarStrideInVector
true
>
;
};
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadSliceLength_O
=
Number
<
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
BlockSliceLength_M_
*
ThreadClusterLength_O
/
BlockSize_
>
{};
// dY matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
ygrad_block_desc_o0_m_o1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
__host__
__device__
static
constexpr
auto
MakeYGradBlockDesc_M_O
()
{
const
auto
O0_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I0
);
const
auto
M_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I1
);
const
auto
O1_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I2
);
static_assert
(
O0_
*
O1_
==
BlockSliceLength_O_
,
""
);
static_assert
(
M_
==
BlockSliceLength_M_
,
""
);
return
transform_tensor_descriptor
(
//(128, 64)
ygrad_block_desc_o0_m_o1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
O0_
,
O1_
)),
//(8, 8)
make_pass_through_transform
(
M_
)),
// 128
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
}
static
constexpr
auto
ygrad_block_desc_m_o
=
MakeYGradBlockDesc_M_O
();
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
,
true
>
;
};
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
Gemm1NPerBlock
>
;
// QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
struct
QGradGemmTile_M_K_N
{
template
<
typename
QGridDesc_K0_M_K1_
>
__device__
static
auto
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock
(
const
QGridDesc_K0_M_K1_
&
q_grid_desc_k0_m_k1
)
{
const
auto
K0
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
K1
=
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
K
=
K0
*
K1
;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
KBlock
=
K
/
Gemm1NPerBlock
;
// NOTE: QGrad gemm is similar to Y gemm
const
auto
q_grid_desc_m_k
=
transform_tensor_descriptor
(
q_grid_desc_k0_m_k1
,
make_tuple
(
make_pass_through_transform
(
M
),
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
q_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
KBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
}
};
struct
KGradGemmTile_N_K_M
{
// C position
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
auto
MakeKGradGridDesc_N_K
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K_K1
=
k_grid_desc_k0_n_k1
.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
k_grid_desc_k0_n_k1
,
make_tuple
(
make_pass_through_transform
(
N
),
make_merge_transform_v3_division_mod
(
make_tuple
(
K_K0
,
K_K1
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
};
struct
SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static
constexpr
auto
q_block_desc_k0_m_k1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
k_block_desc_k0_n_k1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
v_block_desc_k0_n_k1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
ygrad_block_desc_k0_m_k1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
p_slash_sgrad_block_desc_m0_n_m1
=
GetA2BlockDescriptor_M0_N_M1
<
Gemm2Params_N_O_M
>
();
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
DataType
)
>
{};
static
constexpr
auto
q_block_space_size_aligned
=
math
::
integer_least_multiple
(
q_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
k_block_space_size_aligned
=
math
::
integer_least_multiple
(
k_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
v_block_space_size_aligned
=
math
::
integer_least_multiple
(
v_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
ygrad_block_space_size_aligned
=
math
::
integer_least_multiple
(
ygrad_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
p_slash_sgrad_block_space_size_aligned
=
math
::
integer_least_multiple
(
p_slash_sgrad_block_desc_m0_n_m1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
ygrad_block_space_offset
=
0
;
static
constexpr
auto
q_block_space_offset
=
ygrad_block_space_size_aligned
.
value
;
static
constexpr
auto
k_block_space_offset
=
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
;
static
constexpr
auto
v_block_space_offset
=
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
;
static
constexpr
auto
p_slash_sgrad_block_space_offset
=
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
;
// LDS allocation for reduction
static
constexpr
index_t
reduction_space_size_aligned
=
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
static
constexpr
auto
reduction_space_offset
=
(
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
)
*
sizeof
(
DataType
)
/
sizeof
(
FloatGemmAcc
);
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c_block_space_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
const
index_t
k_bytes_end
=
(
SharedMemTrait
::
k_block_space_offset
+
SharedMemTrait
::
k_block_space_size_aligned
)
*
sizeof
(
DataType
);
const
index_t
v_bytes_end
=
(
SharedMemTrait
::
v_block_space_offset
+
SharedMemTrait
::
v_block_space_size_aligned
)
*
sizeof
(
DataType
);
const
index_t
p_slash_sgrad_bytes_end
=
(
SharedMemTrait
::
p_slash_sgrad_block_space_offset
+
SharedMemTrait
::
p_slash_sgrad_block_space_size_aligned
)
*
sizeof
(
DataType
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
k_bytes_end
,
v_bytes_end
,
p_slash_sgrad_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
}
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
VGradGridDescriptor_N_O
,
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
SElementwiseOperation
&
s_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
VGradGridDescriptor_N_O
&
vgrad_grid_desc_n_o
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
ck
::
philox
&
ph
)
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
const
auto
q_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_q_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
const
auto
k_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_k_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
const
auto
v_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_v_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
const
auto
y_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_grid
,
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetElementSpaceSize
());
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
vgrad_grid_desc_n_o
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [M, O]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetLength
(
I0
),
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
// const index_t o_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// LDS allocation for Q / K / V / dY
auto
q_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
q_block_space_offset
,
GemmBlockwiseCopy
::
q_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
k_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
k_block_space_offset
,
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
v_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
v_block_space_offset
,
GemmBlockwiseCopy
::
v_block_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
ygrad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
GemmBlockwiseCopy
::
ygrad_block_desc_k0_m_k1
.
GetElementSpaceSize
());
// Q matrix blockwise copy
auto
gemm_tile_q_blockwise_copy
=
typename
GemmBlockwiseCopy
::
template
QBlockwiseCopy
<
decltype
(
q_grid_desc_k0_m_k1
)>(
q_grid_desc_k0_m_k1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
GemmBlockwiseCopy
::
q_block_desc_k0_m_k1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// K matrix blockwise copy
auto
gemm_tile_k_blockwise_copy
=
typename
GemmBlockwiseCopy
::
template
KBlockwiseCopy
<
decltype
(
k_grid_desc_k0_n_k1
)>(
k_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
b_element_op
,
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// V matrix blockwise copy
auto
gemm_tile_v_blockwise_copy
=
typename
GemmBlockwiseCopy
::
template
VBlockwiseCopy
<
decltype
(
v_grid_desc_o0_n_o1
)>(
v_grid_desc_o0_n_o1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
b1_element_op
,
GemmBlockwiseCopy
::
v_block_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dY matrix blockwise copy
auto
gemm_tile_ygrad_blockwise_copy
=
typename
GemmBlockwiseCopy
::
template
YGradBlockwiseCopy
<
decltype
(
ygrad_grid_desc_o0_m_o1
)>(
ygrad_grid_desc_o0_m_o1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
GemmBlockwiseCopy
::
ygrad_block_desc_k0_m_k1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
//
// set up S / dP Gemm (type 1 rcr)
//
// S: blockwise gemm
auto
s_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
// TransposeC
auto
s_slash_p_thread_buf
=
s_blockwise_gemm
.
GetCThreadBuffer
();
// dP: blockwise gemm
// we need separate blockwise gemm object because we need separate thread buffer
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
auto
pgrad_thread_buf
=
pgrad_blockwise_gemm
.
GetCThreadBuffer
();
//
// set up Y / dQ Gemm (type 2 rrr)
//
// Note: Y is pre-calculated in forward pass and loaded to backward pass kernel
using
Gemm1
=
Gemm1
<
decltype
(
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()),
decltype
(
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
())
>
;
// Gemm1: VGPR allocation for A and B
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
Gemm1
::
a_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
gemm1_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
Gemm1
::
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
.
GetElementSpaceSize
());
// dQ: transform input and output tensor descriptors
auto
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
=
QGradGemmTile_M_K_N
::
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock
(
q_grid_desc_k0_m_k1
);
// dQ: A matrix blockwise copy
auto
qgrad_gemm_tile_sgrad_blockwise_copy
=
typename
Gemm1
::
ABlockwiseCopy
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// dQ: blockwise gemm
auto
qgrad_blockwise_gemm
=
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
),
make_tuple
(
0
,
0
,
0
,
0
)};
// dQ: B matrix blockwise copy
auto
k_thread_origin
=
qgrad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
auto
qgrad_gemm_tile_k_blockwise_copy
=
typename
Gemm1
::
BBlockwiseCopy
{
Gemm1
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
make_multi_index
(
0
,
// nrepeat
k_thread_origin
[
I1
],
// nwave
k_thread_origin
[
I2
],
// nperxdl
0
,
// k0
0
,
// k1
k_thread_origin
[
I3
]
/
Gemm1
::
GemmKPack
,
// k2
0
)};
// k3
auto
qgrad_thread_buf
=
qgrad_blockwise_gemm
.
GetCThreadBuffer
();
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr
auto
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
()
/
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
tm0
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I0
);
constexpr
auto
tn0
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I1
);
constexpr
auto
tm1
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I2
);
constexpr
auto
tn1
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I3
);
constexpr
auto
tm2
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I4
);
constexpr
auto
tn2
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I5
);
constexpr
auto
tn3
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I6
);
constexpr
auto
tn4
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I7
);
// get acc0 thread map
constexpr
auto
m0_n_m1_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
tm0
*
tm1
,
tm2
)),
make_pass_through_transform
(
I1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
constexpr
auto
threadid_to_m0_n_m1_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
tm0
*
tm1
,
tn0
*
tn1
*
tn2
*
tn3
*
tn4
,
tm2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
threadid_to_m_n_thread_cluster_adaptor
=
chain_tensor_adaptors
(
m0_n_m1_to_m_n_adaptor
,
threadid_to_m0_n_m1_adaptor
);
// get acc0 2D thread cluster & 2D thread slice
constexpr
auto
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
n0
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
m1
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
n1
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
m2
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
n2
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
n3
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
n4
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
thread_cluster_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
tm0
*
tm1
*
tm2
,
tn0
*
tn1
*
tn2
*
tn3
*
tn4
));
constexpr
auto
thread_slice_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
*
m1
*
m2
,
n0
*
n1
*
n2
*
n3
*
n4
));
auto
blockwise_softmax
=
BlockwiseSoftmax
<
BlockSize
,
FloatGemmAcc
,
decltype
(
threadid_to_m_n_thread_cluster_adaptor
),
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_16bits
,
rp_dropout
};
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
lse_grid_desc_m
);
constexpr
auto
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
m0
,
m1
,
m2
));
auto
lse_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatLSE
>
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
.
GetElementSpaceSize
());
auto
acc0_thread_origin
=
s_blockwise_gemm
.
CalculateCThreadOriginDataIndex8D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
auto
lse_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatLSE
,
FloatLSE
,
decltype
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
),
Sequence
<
1
,
m0
,
m1
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
m2
,
1
,
false
>
{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// registerNum
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up dV / dK Gemm (type 3 crr)
//
using
Gemm2
=
Gemm2
<
Gemm2Params_N_O_M
,
decltype
(
s_blockwise_gemm
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_slash_sgrad_block_space_offset
,
Gemm2
::
a_block_desc_m0_n_m1
.
GetElementSpaceSize
());
auto
gemm2_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
Gemm2
::
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
.
GetElementSpaceSize
());
// dV: transform input and output tensor descriptors
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
// dV: A matrix VGPR-to-LDS blockwise copy
auto
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
template
ABlockwiseCopy
<
tensor_operation
::
element_wise
::
Relu
>{
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
Relu
{}};
// relu(P-dropped)
// dV: blockwise gemm
auto
v_slash_k_grad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{};
v_slash_k_grad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
auto
q_slash_ygrad_thread_origin
=
v_slash_k_grad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
// dV: B matrix LDS-to-VGPR blockwise copy
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
typename
Gemm2
::
BBlockwiseCopy
{
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
make_multi_index
(
0
,
// orepeat
q_slash_ygrad_thread_origin
[
I1
],
// owave
q_slash_ygrad_thread_origin
[
I2
],
// nperxdl
0
,
// m0
0
,
// m1
q_slash_ygrad_thread_origin
[
I3
]
/
Gemm2Params_N_O_M
::
GemmMPack
,
// m2
0
)};
// m3
auto
v_slash_k_grad_thread_buf
=
v_slash_k_grad_blockwise_gemm
.
GetCThreadBuffer
();
// dV: C VGPR-to-global copy
const
auto
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
I0
,
block_work_idx
[
I1
]
*
Gemm2Params_N_O_M
::
GemmORepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
vgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
tensor_operation
::
element_wise
::
Scale
>(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
// dK: transform output tensor descriptors
const
auto
kgrad_grid_desc_n_k
=
KGradGemmTile_N_K_M
::
MakeKGradGridDesc_N_K
(
k_grid_desc_k0_n_k1
);
const
auto
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
kgrad_grid_desc_n_k
);
// dK: A matrix VGPR-to-LDS blockwise copy
auto
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
template
ABlockwiseCopy
<
tensor_operation
::
element_wise
::
PassThrough
>{
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
PassThrough
{}};
// dK: B matrix LDS-to-VGPR blockwise copy
auto
kgrad_gemm_tile_q_blockwise_copy
=
typename
Gemm2
::
BBlockwiseCopy
{
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
make_multi_index
(
0
,
// orepeat
q_slash_ygrad_thread_origin
[
I1
],
// owave
q_slash_ygrad_thread_origin
[
I2
],
// nperxdl
0
,
// m0
0
,
// m1
q_slash_ygrad_thread_origin
[
I3
]
/
Gemm2Params_N_O_M
::
GemmMPack
,
// m2
0
)};
// m3
// dK: blockwise gemm
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
// dK: C VGPR-to-global copy
const
auto
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
I0
,
block_work_idx
[
I1
]
*
Gemm2Params_N_O_M
::
GemmORepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
kgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
decltype
(
scale_rp_dropout
)>(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
scale_rp_dropout
);
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr
auto
p_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
P_M0
=
p_block_lengths
[
I0
];
// repeats
constexpr
auto
P_M1
=
p_block_lengths
[
I2
];
// waves
constexpr
auto
P_M2
=
p_block_lengths
[
I4
];
// xdl
constexpr
auto
y_thread_desc_m0_m1_o0_o1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
ygrad_thread_desc_m_o
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
YDotYGrad_M_O
::
ThreadSliceLength_M
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
const
auto
y_thread_cluster_idx
=
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
constexpr
auto
ygrad_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
YDotYGrad_M_O
::
ThreadClusterLength_M
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
>
{});
const
auto
ygrad_thread_cluster_idx
=
ygrad_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
const
auto
ygrad_thread_data_on_block_idx
=
ygrad_thread_cluster_idx
*
ygrad_thread_desc_m_o
.
GetLengths
();
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
block_work_idx
[
I0
],
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
y_thread_data_on_block_idx
;
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
// performs for ygrad
auto
ygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
FloatGemmAcc
,
decltype
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
.
GetLengths
()),
Sequence
<
0
,
1
>
,
1
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
,
ygrad_thread_data_on_block_idx
);
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_O
::
DstBufType
{};
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
MPerBlock
);
constexpr
auto
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
),
make_tuple
(
P_M0
*
P_M1
*
P_M2
,
P_M1
*
P_M2
,
P_M2
,
I1
));
constexpr
auto
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
=
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
;
// reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto
y_dot_ygrad_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
),
Sequence
<
1
,
m0
,
m1
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
m2
,
1
,
false
>
{
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
I0
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
.
GetElementSpaceSize
());
// load ygrad
gemm_tile_ygrad_blockwise_copy
.
Run
(
ygrad_grid_desc_o0_m_o1
,
ygrad_grid_buf
,
GemmBlockwiseCopy
::
ygrad_block_desc_k0_m_k1
,
ygrad_block_buf
,
I0
);
block_sync_lds
();
//
// calculate Y dot dY
//
// clear accum buffers
y_dot_ygrad_thread_accum_buf
.
Clear
();
y_dot_ygrad_block_accum_buf
.
Clear
();
y_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_grid_buf
,
y_thread_desc_m0_m1_o0_o1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
y_thread_buf
);
ygrad_threadwise_copy
.
Run
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
,
ygrad_block_buf
,
ygrad_thread_desc_m_o
,
make_tuple
(
I0
,
I0
),
ygrad_thread_buf
);
static_for
<
0
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
YDotYGrad_M_O
::
ThreadSliceLength_O
,
1
>
{}([
&
](
auto
iO
)
{
constexpr
auto
y_offset
=
y_thread_desc_m0_m1_o0_o1
.
CalculateOffset
(
make_multi_index
(
I0
,
iM
,
I0
,
iO
));
constexpr
auto
ygrad_offset
=
ygrad_thread_desc_m_o
.
CalculateOffset
(
make_multi_index
(
iM
,
iO
));
y_dot_ygrad_thread_accum_buf
(
iM
)
+=
y_thread_buf
[
Number
<
y_offset
>
{}]
*
ygrad_thread_buf
[
Number
<
ygrad_offset
>
{}];
});
});
// blockwise reduction using atomic_add
block_sync_lds
();
static_for
<
0
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
1
>
{}([
&
](
auto
iM
)
{
const
auto
idx_on_block
=
y_thread_data_on_block_idx
[
I1
]
+
iM
;
y_dot_ygrad_block_accum_buf
.
AtomicAdd
(
idx_on_block
,
true
,
y_dot_ygrad_thread_accum_buf
[
iM
]
*
p_dropout
);
// p_dropoutD1
});
block_sync_lds
();
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr
.
Run
(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
,
y_dot_ygrad_block_accum_buf
,
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
y_dot_ygrad_thread_buf
);
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
lse_grid_buf
,
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
lse_thread_buf
);
const
index_t
num_gemm1_k_block_outer_loop
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
)
/
NPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
NPerBlock
/
Gemm1KPerBlock
;
// Initialize dQ
qgrad_thread_buf
.
Clear
();
// load q
gemm_tile_q_blockwise_copy
.
Run
(
q_grid_desc_k0_m_k1
,
q_grid_buf
,
GemmBlockwiseCopy
::
q_block_desc_k0_m_k1
,
q_block_buf
,
I0
);
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
auto
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
(
c0_matrix_mask
.
IsTileSkippable
(
m_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
MPerBlock
,
NPerBlock
))
{
continue
;
}
// gemm dP
// dP = dY * V^T
pgrad_thread_buf
.
Clear
();
gemm_tile_v_blockwise_copy
.
Run
(
v_grid_desc_o0_n_o1
,
v_grid_buf
,
GemmBlockwiseCopy
::
v_block_desc_k0_n_k1
,
v_block_buf
,
I0
);
block_sync_lds
();
pgrad_blockwise_gemm
.
Run
(
ygrad_block_buf
,
v_block_buf
,
pgrad_thread_buf
);
// gemm S
// S = Q * K^T
s_slash_p_thread_buf
.
Clear
();
gemm_tile_k_blockwise_copy
.
Run
(
k_grid_desc_k0_n_k1
,
k_grid_buf
,
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
k_block_buf
,
I0
);
block_sync_lds
();
s_blockwise_gemm
.
Run
(
q_block_buf
,
k_block_buf
,
s_slash_p_thread_buf
);
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
s_slash_p_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
s_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
}
});
}
else
{
static_for
<
0
,
s_slash_p_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
s_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
if
(
p_z_grid
)
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
}
else
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
}
block_sync_lds
();
// wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
auto
&
sgrad_thread_buf
=
pgrad_thread_buf
;
constexpr
auto
pgrad_thread_tile_iterator
=
pgrad_blockwise_gemm
.
MakeCThreadTileIterator
();
constexpr
auto
pgrad_thread_idx_to_m_n_adaptor
=
pgrad_blockwise_gemm
.
MakeCThreadIndexAdaptor8DTo2D
();
static_for
<
0
,
pgrad_thread_tile_iterator
.
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
pgrad_thread_idx
=
pgrad_thread_tile_iterator
.
GetIndex
(
i
);
constexpr
auto
m
=
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
// dS and P has same thread buf layout
if
(
s_slash_p_thread_buf
[
i
]
>=
0
)
{
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
}
else
{
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}];
}
});
// gemm dQ
// dQ = scalar * dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// main body
static_for
<
0
,
num_gemm1_k_block_inner_loop
,
1
>
{}([
&
](
auto
i
)
{
qgrad_gemm_tile_sgrad_blockwise_copy
.
Run
(
Gemm1
::
a_src_thread_desc_k0_m_k1
,
Gemm1
::
a_block_slice_copy_step
*
i
,
sgrad_thread_buf
,
Gemm1
::
a_thread_desc_k0_m_k1
,
make_tuple
(
I0
,
I0
,
I0
),
gemm1_a_thread_buf
);
qgrad_gemm_tile_k_blockwise_copy
.
Run
(
Gemm1
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
k_block_buf
,
Gemm1
::
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
gemm1_b_thread_buf
);
qgrad_gemm_tile_k_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm1
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
Gemm1
::
b_block_slice_copy_step
);
block_sync_lds
();
qgrad_blockwise_gemm
.
Run
(
gemm1_a_thread_buf
,
gemm1_b_thread_buf
,
qgrad_thread_buf
);
// block_sync_lds();
});
}
// end gemm dQ
SubThreadBlock
<
BlockSize
>
gemm2_a_copy_subgroup
(
s_blockwise_gemm
.
GetWaveIdx
()[
I0
],
s_blockwise_gemm
.
GetWaveIdx
()[
I1
]);
constexpr
index_t
num_gemm2_loop
=
MPerBlock
/
Gemm2Params_N_O_M
::
Sum_M
;
static_assert
(
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetNumOfAccess
()
==
num_gemm2_loop
,
""
);
// TODO: tune gemm2 pipeline
// dV = P_drop^T * dY
v_slash_k_grad_thread_buf
.
Clear
();
static_for
<
0
,
num_gemm2_loop
,
1
>
{}([
&
](
auto
gemm2_loop_idx
)
{
// gemm dV
// load VGrad Gemm A
const
auto
p_slice_idx
=
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetIndexTupleOfNumber
(
gemm2_loop_idx
);
constexpr
auto
mwave_range
=
make_tuple
(
p_slice_idx
[
I2
],
p_slice_idx
[
I2
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I2
));
constexpr
auto
nwave_range
=
make_tuple
(
p_slice_idx
[
I3
],
p_slice_idx
[
I3
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I3
));
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
.
Run
(
Gemm2
::
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
p_slice_idx
[
I0
],
p_slice_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
s_slash_p_thread_buf
,
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
gemm2_a_block_buf
);
}
// block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy
.
Run
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
ygrad_block_buf
,
Gemm2
::
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
gemm2_b_thread_buf
);
vgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
Gemm2
::
b_block_slice_copy_step
);
block_sync_lds
();
// sync before read
v_slash_k_grad_blockwise_gemm
.
Run
(
gemm2_a_block_buf
,
gemm2_b_thread_buf
,
v_slash_k_grad_thread_buf
);
});
// end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global
.
Run
(
Gemm2
::
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
v_slash_k_grad_thread_buf
,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
vgrad_grid_buf
);
// dK = scalar * dS^T * dQ
v_slash_k_grad_thread_buf
.
Clear
();
// gemm2_b_thread_buf.Clear();
static_for
<
0
,
num_gemm2_loop
,
1
>
{}([
&
](
auto
gemm2_loop_idx
)
{
// gemm dK
// load KGrad Gemm A
const
auto
sgrad_slice_idx
=
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetIndexTupleOfNumber
(
gemm2_loop_idx
);
constexpr
auto
mwave_range
=
make_tuple
(
sgrad_slice_idx
[
I2
],
sgrad_slice_idx
[
I2
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I2
));
constexpr
auto
nwave_range
=
make_tuple
(
sgrad_slice_idx
[
I3
],
sgrad_slice_idx
[
I3
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I3
));
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
.
Run
(
Gemm2
::
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
sgrad_slice_idx
[
I0
],
sgrad_slice_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
gemm2_a_block_buf
);
}
kgrad_gemm_tile_q_blockwise_copy
.
Run
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
q_block_buf
,
Gemm2
::
b_thread_desc_o0_o1_o2_m0_m1_m2_m3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
gemm2_b_thread_buf
);
kgrad_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
Gemm2
::
b_block_slice_copy_step
);
block_sync_lds
();
// sync before read
v_slash_k_grad_blockwise_gemm
.
Run
(
gemm2_a_block_buf
,
gemm2_b_thread_buf
,
v_slash_k_grad_thread_buf
);
});
// end gemm dK
// atomic_add dK
kgrad_thread_copy_vgpr_to_global
.
Run
(
Gemm2
::
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
v_slash_k_grad_thread_buf
,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_grid_buf
);
// move slice window
gemm_tile_k_blockwise_copy
.
MoveSrcSliceWindow
(
k_grid_desc_k0_n_k1
,
GemmBlockwiseCopy
::
gemm_tile_k_block_slice_copy_step
);
// step N
gemm_tile_v_blockwise_copy
.
MoveSrcSliceWindow
(
v_grid_desc_o0_n_o1
,
GemmBlockwiseCopy
::
gemm_tile_v_block_slice_copy_step
);
// step N
vgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
Gemm2
::
b_block_reset_copy_step
);
// rewind M
vgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
qgrad_gemm_tile_k_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm1
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
Gemm1
::
b_block_reset_copy_step
);
// rewind K
kgrad_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
Gemm2
::
b_block_desc_o0_o1_o2_m0_m1_m2_m3
,
Gemm2
::
b_block_reset_copy_step
);
// rewind M
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
// shuffle dQ and write
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
Gemm1NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
qgrad_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
=
qgrad_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I4
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I5
);
constexpr
auto
N3
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I6
);
constexpr
auto
N4
=
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatCShuffle
*>
(
p_shared
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
)),
// M2 = MPerXdl
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
,
// N2 * N3 * N4 = NPerXdl
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
qgrad_blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
SElementwiseOperation
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
I1
,
N2
,
I1
,
N4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
n_thread_data_on_block_idx
[
I2
],
n_thread_data_on_block_idx
[
I3
],
n_thread_data_on_block_idx
[
I4
]),
scale_rp_dropout
};
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
Gemm1NXdlPerWave
,
1
,
1
,
1
,
N2
,
1
,
N4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
1
,
N2
,
1
,
N4
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
Gemm1NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
qgrad_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
,
qgrad_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
,
c_global_step
);
}
});
}
}
};
}
// namespace ck
include/ck/utility/generic_memory_space_atomic.hpp
View file @
20fc6679
...
...
@@ -71,6 +71,75 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return
vy
.
template
AsType
<
double2_t
>()[
I0
];
}
inline
__host__
__device__
half2_t
add_fp16x2_t
(
const
half2_t
&
a
,
const
half2_t
&
b
)
{
half2_t
rtn
;
rtn
[
0
]
=
a
[
0
]
+
b
[
0
];
rtn
[
1
]
=
a
[
1
]
+
b
[
1
];
return
rtn
;
}
template
<
>
__device__
half2_t
atomic_add
<
half2_t
>
(
half2_t
*
p_dst
,
const
half2_t
&
x
)
{
uint32_t
*
dword_addr
=
reinterpret_cast
<
uint32_t
*>
(
p_dst
);
uint32_t
cur_v
=
*
dword_addr
;
uint32_t
old_v
,
new_v
;
do
{
old_v
=
cur_v
;
half2_t
new_
=
add_fp16x2_t
(
*
reinterpret_cast
<
half2_t
*>
(
&
cur_v
),
x
);
new_v
=
*
reinterpret_cast
<
uint32_t
*>
(
&
new_
);
cur_v
=
atomicCAS
(
dword_addr
,
old_v
,
new_v
);
}
while
(
cur_v
!=
old_v
);
return
x
;
}
// union U16BF16 {
// uint16_t u16;
// bhalf_t bf16;
// };
// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
// U16BF16 xa {.bf16 = a};
// U16BF16 xb {.bf16 = b};
// U16BF16 xr;
// xr.u16 = xa.u16 + xb.u16;
// return xr.bf16;
// }
inline
__host__
__device__
bhalf_t
add_bf16_t
(
const
bhalf_t
&
a
,
const
bhalf_t
&
b
)
{
return
type_convert
<
bhalf_t
>
(
type_convert
<
float
>
(
a
)
+
type_convert
<
float
>
(
b
));
}
inline
__host__
__device__
bhalf2_t
add_bf16x2_t
(
const
bhalf2_t
&
a
,
const
bhalf2_t
&
b
)
{
bhalf2_t
rtn
;
rtn
[
0
]
=
add_bf16_t
(
a
[
0
],
b
[
0
]);
rtn
[
1
]
=
add_bf16_t
(
a
[
1
],
b
[
1
]);
return
rtn
;
}
template
<
>
__device__
bhalf2_t
atomic_add
<
bhalf2_t
>
(
bhalf2_t
*
p_dst
,
const
bhalf2_t
&
x
)
{
uint32_t
*
dword_addr
=
reinterpret_cast
<
uint32_t
*>
(
p_dst
);
uint32_t
cur_v
=
*
dword_addr
;
uint32_t
old_v
,
new_v
;
do
{
old_v
=
cur_v
;
bhalf2_t
new_
=
add_bf16x2_t
(
*
reinterpret_cast
<
bhalf2_t
*>
(
&
cur_v
),
x
);
new_v
=
*
reinterpret_cast
<
uint32_t
*>
(
&
new_
);
cur_v
=
atomicCAS
(
dword_addr
,
old_v
,
new_v
);
}
while
(
cur_v
!=
old_v
);
return
x
;
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment