gaoqiong / composable_kernel

Commit 66052232
Authored Feb 13, 2023 by danyao12

sync attn-bwd-dropout

Parents: 5eb5e316, bf80ceee

Showing 13 changed files with 1814 additions and 3219 deletions
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (+4, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+153, -26)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp  (+0, -808)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp  (+5, -0)
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc  (+20, -0)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp  (+5, -0)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp  (+105, -11)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle  (+1058, -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+257, -51)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+0, -2316)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp  (+168, -1)
include/ck/utility/data_type.hpp  (+36, -0)
include/ck/utility/philox_rand.hpp  (+3, -6)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -10,8 +10,12 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
 add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
+<<<<<<< HEAD
 add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
+=======
+>>>>>>> attn-bwd-dropout
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -43,23 +43,27 @@ Kernel outputs:
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
 
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;
+using VElementOp   = Scale;
 
 using DataType         = F16;
 using AccDataType      = F32;
 using ShuffleDataType  = F32;
 using LSEDataType      = F32;
+using ZDataType        = U16;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -91,6 +95,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -182,12 +187,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
                                                                          PassThrough,
                                                                          PassThrough,
                                                                          Scale>;
+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
 
 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
           typename TensorS,
           typename TensorP,
+          typename TensorZ,
           typename TensorY,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
@@ -197,7 +206,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
                             TensorY& y_g_m_o,
-                            TensorLSE& lse_g_m)
+                            TensorLSE& lse_g_m,
+                            TensorP& p_drop_g_m_n,
+                            TensorZ& z_g_m_n,
+                            ushort p_dropout_in_16bits,
+                            float rp_dropout)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
@@ -225,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_softmax_invoker.Run(ref_softmax_argument);
 
-    // Y = P * V
+    // P_dropped
+    auto ref_dropout         = ReferenceDropoutInstance{};
+    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+    auto ref_dropout_argment =
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+    ref_dropout_invoker.Run(ref_dropout_argment);
+
+    // Y = P_dropout * V
     auto ref_gemm1          = ReferenceGemm1Instance{};
     auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
     auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
+        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
     ref_gemm1_invoker.Run(ref_gemm1_argument);
 }
@@ -256,6 +276,13 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
 
+    float p_drop                    = 0.2;
+    float p_dropout                 = 1 - p_drop;
+    uint16_t p_dropout_in_16bits    = uint16_t(std::floor(p_dropout * 65535.0));
+    float rp_dropout                = 1.0 / p_dropout;
+    const unsigned long long seed   = 1;
+    const unsigned long long offset = 0;
+
     if(argc == 1)
     {
         // use default case
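The six lines added above are the whole dropout configuration for the example: the keep probability p_dropout is re-encoded as a 16-bit integer threshold plus a reciprocal rescale factor, and seed/offset feed the Philox generator. A minimal sketch of the assumed element-wise semantics follows; the helper name and the direction of the comparison against the 16-bit random draw are illustrative assumptions, not taken from the library's ReferenceDropout source.

#include <cstdint>

// Illustrative only: keep an element when its 16-bit random draw falls at or below
// the threshold, and rescale kept elements by rp_dropout = 1 / (1 - p_drop) so the
// expected value of the dropped matrix matches the original one.
inline float apply_dropout_elem(float p,
                                std::uint16_t z_rand,
                                std::uint16_t p_dropout_in_16bits,
                                float rp_dropout)
{
    return (z_rand <= p_dropout_in_16bits) ? p * rp_dropout : 0.0f;
}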
@@ -321,6 +348,11 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
 
+    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> z_gs_ms_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
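For readers following the permuted layouts, the two Z stride vectors added above map an index (g0, g1, m, n) to a flat element offset as a plain dot product with the strides. The small helper below is illustrative only (it assumes the example's existing includes and ck::index_t) and is not part of the committed file.

// Strides are listed per dimension in [G0, G1, M, N] order.
// {M * G1 * N, N, G1 * N, 1} stores Z as [G0, M, G1, N] in memory,
// while {G1 * M * N, M * N, N, 1} stores it as the plain [G0, G1, M, N] layout.
inline std::size_t z_flat_offset(const std::vector<ck::index_t>& strides,
                                 ck::index_t g0, ck::index_t g1, ck::index_t m, ck::index_t n)
{
    return std::size_t(g0) * strides[0] + std::size_t(g1) * strides[1] +
           std::size_t(m) * strides[2] + std::size_t(n) * strides[3];
}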
@@ -332,6 +364,7 @@ int run(int argc, char* argv[])
     Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+    Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
@@ -339,10 +372,12 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
+    std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
     std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
 
+    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
     switch(init_method)
     {
     case 0: break;
@@ -408,9 +443,11 @@ int run(int argc, char* argv[])
     // calculate y & log-sum-exp beforehand
     Tensor<DataType> q_g_m_k({BatchCount, M, K});
     Tensor<DataType> k_g_n_k({BatchCount, N, K});
+    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
     Tensor<DataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
     Tensor<DataType> p_g_m_n({BatchCount, M, N});
+    Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
     Tensor<DataType> y_g_m_o({BatchCount, M, O});
     Tensor<LSEDataType> lse_g_m({BatchCount, M});
@@ -418,12 +455,25 @@ int run(int argc, char* argv[])
         [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     k_gs_ns_ks.ForEach(
         [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+    z_gs_ms_ns.ForEach(
+        [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     v_gs_os_ns.ForEach(
         [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
     lse_gs_ms.ForEach(
         [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
 
-    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+    run_attention_fwd_host(q_g_m_k,
+                           k_g_n_k,
+                           v_g_n_o,
+                           alpha,
+                           s_g_m_n,
+                           p_g_m_n,
+                           y_g_m_o,
+                           lse_g_m,
+                           p_drop_g_m_n,
+                           z_g_m_n,
+                           p_dropout_in_16bits,
+                           rp_dropout);
 
     y_gs_ms_os.ForEach(
         [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
@@ -433,6 +483,7 @@ int run(int argc, char* argv[])
     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
     DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
@@ -443,6 +494,7 @@ int run(int argc, char* argv[])
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
     lse_device_buf.ToDevice(lse_gs_ms.mData.data());
@@ -452,9 +504,12 @@ int run(int argc, char* argv[])
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
 
+    // get z matrix
+    {
     auto argument = gemm.MakeArgument(
         static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+        static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
@@ -468,6 +523,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
         k_gs_ns_ks_strides,
+        z_gs_ms_ns_lengths,
+        z_gs_ms_ns_strides,
         v_gs_os_ns_lengths,
         v_gs_os_ns_strides,
         y_gs_ms_os_lengths,
@@ -481,7 +538,9 @@ int run(int argc, char* argv[])
         QKVElementOp{},
         Scale{alpha},
         QKVElementOp{},
-        YElementOp{});
+        YElementOp{},
+        p_drop,
+        std::tuple<unsigned long long, unsigned long long>(seed, offset));
 
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -489,7 +548,46 @@ int run(int argc, char* argv[])
         return 0;
     }
 
+    invoker.Run(argument, StreamConfig{nullptr, false});
+    }
+
+    // not need output z matrix
+    auto argument = gemm.MakeArgument(
+        static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+        static_cast<ZDataType*>(nullptr), // set to nullptr
+        static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+        static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
+        static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+        {}, // std::array<void*, 1> p_acc0_biases;
+        {}, // std::array<void*, 1> p_acc1_biases;
+        q_gs_ms_ks_lengths,
+        q_gs_ms_ks_strides,
+        k_gs_ns_ks_lengths,
+        k_gs_ns_ks_strides,
+        z_gs_ms_ns_lengths,
+        z_gs_ms_ns_strides,
+        v_gs_os_ns_lengths,
+        v_gs_os_ns_strides,
+        y_gs_ms_os_lengths,
+        y_gs_ms_os_strides,
+        lse_gs_ms_lengths,
+        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
+        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
+        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
+        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
+        QKVElementOp{},
+        QKVElementOp{},
+        Scale{alpha},
+        QKVElementOp{},
+        YElementOp{},
+        p_drop,
+        std::tuple<unsigned long long, unsigned long long>(seed, offset));
+
+    kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
+    vgrad_device_buf.SetZero();
+
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
 
     // 5 GEMM ops in total:
@@ -511,9 +609,32 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;
 
+    // copy z matirx data form device
+    z_device_buf.FromDevice(z_g_m_n.mData.data());
+    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
+
     bool pass = true;
     if(do_verification)
     {
+        // run fowad again for y, cause z_g_m_n update
+        run_attention_fwd_host(q_g_m_k,
+                               k_g_n_k,
+                               v_g_n_o,
+                               alpha,
+                               s_g_m_n,
+                               p_g_m_n,
+                               y_g_m_o,
+                               lse_g_m,
+                               p_drop_g_m_n,
+                               z_g_m_n,
+                               p_dropout_in_16bits,
+                               rp_dropout);
+        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
+            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+        });
+        y_device_buf.ToDevice(y_gs_ms_os.mData.data());
+        // call kernel again
         kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
         vgrad_device_buf.SetZero();
         invoker.Run(argument, StreamConfig{nullptr, false});
@@ -523,6 +644,7 @@ int run(int argc, char* argv[])
         Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
         Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
         Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
         Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
         Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
@@ -544,18 +666,24 @@ int run(int argc, char* argv[])
         auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
         using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
-        // dP = dY * V^T
+        // dP_dropout = dY * V^T
         auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
 #if PRINT_HOST
         {
             std::cout << "===== dP = dY * V^T\n";
-            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
+            std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
             std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-            std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
+            std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
         }
 #endif
+        // dP = dP_dropout x Z
+        auto ref_dropout         = ReferenceDropoutInstance{};
+        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+        auto ref_dropout_argment = ref_dropout.MakeArgument(
+            z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout_invoker.Run(ref_dropout_argment);
+
         // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
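The comment above is the standard softmax-backward identity. Written out (for the dropout-free case, which is what makes the per-row dot product dY_i · Y_i appear; the same final form holds once the saved mask has been folded into dP by the ReferenceDropout call above):

\[
\frac{\partial P_{ik}}{\partial S_{ij}} = P_{ik}\left(\delta_{kj} - P_{ij}\right)
\quad\Longrightarrow\quad
dS_{ij} = \sum_k dP_{ik}\,P_{ik}\left(\delta_{kj} - P_{ij}\right)
        = P_{ij}\Bigl(dP_{ij} - \sum_k dP_{ik}\,P_{ik}\Bigr),
\]
\[
\text{and with } Y = P\,V,\ dP = dY\,V^{T}:\qquad
\sum_k dP_{ik}\,P_{ik} = \sum_o dY_{io}\,Y_{io} = dY_i \cdot Y_i .
\]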
@@ -578,15 +706,14 @@ int run(int argc, char* argv[])
             std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
         }
 #endif
-        // dV = P^T * dY
-        auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
+        // dV = P_drop^T * dY
+        auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+            p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
 #if PRINT_HOST
         {
             std::cout << "===== dV = P^T * dY\n";
-            std::cout << "p_g_n_m ref:\n" << p_g_n_m;
+            std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
             std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
             std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
         }
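Collected in one place, the backward chain that the hunks above and the reference code in the deleted example below verify (restating the code comments; alpha is the softmax scale, P_drop the dropout-masked probability matrix, and mask(·) reapplies the saved pattern Z with the 1/(1 - p_drop) rescale):

\[
dP_{\text{drop}} = dY\,V^{T}, \qquad
dP = \operatorname{mask}\bigl(dP_{\text{drop}}\bigr), \qquad
dS_{ij} = P_{ij}\bigl(dP_{ij} - dY_i \cdot Y_i\bigr),
\]
\[
dV = P_{\text{drop}}^{T}\,dY, \qquad
dQ = \alpha\, dS\, K, \qquad
dK = \alpha\, dS^{T} Q .
\]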
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
  Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o

Computation graph:

          K^T                  V
           |                   |
           |                   |
  Q --- * ----- Softmax ----- * --> Y
          S               P

Kernel inputs:
  Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
  dQ, dK, dV
*/

#define PRINT_HOST 0
#define USING_MASK 1

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

using QKVElementOp = PassThrough;
using YElementOp   = PassThrough;
using VElementOp   = Scale;

using DataType         = F16;
using AccDataType      = F32;
using ShuffleDataType  = F32;
using LSEDataType      = F32;
using ZDataType        = U16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif

static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        DataType,
        ZDataType,
        LSEDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        ShuffleDataType,
        QKVElementOp,
        QKVElementOp,
        Scale,
        QKVElementOp,
        YElementOp,
        GemmSpec,
        TensorSpecQ,
        TensorSpecK,
        TensorSpecV,
        TensorSpecY,
        1,
        256,
        128,            // MPerBlock
        128,            // NPerBlock
        32,             // KPerBlock
        128,            // Gemm1NPerBlock
        64,             // Gemm1KPerBlock
        8,              // AK1
        8,              // BK1
        2,              // B1K1
        32,             // MPerXDL
        32,             // NPerXDL
        1,              // MXdlPerWave
        4,              // NXdlPerWave
        4,              // Gemm1NXdlPerWave
        S<4, 64, 1>,    // ABlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2, 8, 8, true,
        S<4, 64, 1>,    // BBlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2, 8, 8, true,
        S<8, 32, 1>,    // B1BlockTransfer
        S<0, 2, 1>,
        S<0, 2, 1>,
        1, 4, 2, false,
        1,              // CShuffleMXdlPerWavePerShuffle
        4,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::
    ReferenceBatchedGemm<DataType, DataType, AccDataType, AccDataType, PassThrough, PassThrough, Scale>;

// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>;

// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::
    ReferenceBatchedGemm<DataType, DataType, DataType, AccDataType, PassThrough, PassThrough, PassThrough>;

// Ref Gemm for backward pass
// fp16 in, fp16 out
using ReferenceGemmGradInstance = ck::tensor_operation::host::
    ReferenceBatchedGemm<DataType, DataType, DataType, AccDataType, PassThrough, PassThrough, Scale>;

// Ref dropout
using ReferenceDropoutInstance =
    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
template <typename TensorQ,
          typename TensorK,
          typename TensorV,
          typename TensorS,
          typename TensorP,
          typename TensorZ,
          typename TensorY,
          typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
                            const TensorK& k_g_n_k,
                            const TensorV& v_g_n_o,
                            const float alpha,
                            TensorS& s_g_m_n,
                            TensorP& p_g_m_n,
                            TensorY& y_g_m_o,
                            TensorLSE& lse_g_m,
                            TensorP& p_drop_g_m_n,
                            TensorZ& z_g_m_n,
                            ushort p_dropout_in_16bits,
                            float rp_dropout)
{
    // S = alpha * Q * K^T
    auto k_g_k_n            = k_g_n_k.Transpose({0, 2, 1});
    auto ref_gemm0          = ReferenceGemm0Instance{};
    auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
    auto ref_gemm0_argument = ref_gemm0.MakeArgument(
        q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
    ref_gemm0_invoker.Run(ref_gemm0_argument);

    // masking
#if USING_MASK
    auto N          = s_g_m_n.GetLengths()[2];
    const auto mask = DeviceGemmInstance::C0MatrixMask(N);
    s_g_m_n.ForEach([&](auto& self, auto idx) {
        if(mask.IsMaskedElement(idx[1], idx[2]))
            self(idx) = -ck::NumericLimits<float>::Infinity();
    });
#endif

    // P = Softmax(S)
    auto ref_softmax          = ReferenceSoftmaxInstance{};
    auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
    ref_softmax_invoker.Run(ref_softmax_argument);

    // P_dropped
    auto ref_dropout         = ReferenceDropoutInstance{};
    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
    auto ref_dropout_argment =
        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
    ref_dropout_invoker.Run(ref_dropout_argment);

    // Y = P_dropout * V
    auto ref_gemm1          = ReferenceGemm1Instance{};
    auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
    auto ref_gemm1_argument = ref_gemm1.MakeArgument(
        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
    ref_gemm1_invoker.Run(ref_gemm1_argument);
}
int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 2; // method 1 will have slightly higher error; TODO: to investigate
    bool time_kernel     = true;

    // Overall QKV matrices shape
    // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
    ck::index_t M  = 512;
    ck::index_t N  = 512;
    ck::index_t K  = 128;
    ck::index_t O  = 128;
    ck::index_t G0 = 3;
    ck::index_t G1 = 2;

    float alpha = 1.f / std::sqrt(K);

    bool input_permute  = false;
    bool output_permute = false;

    float p_drop                    = 0.2;
    float p_dropout                 = 1 - p_drop;
    uint16_t p_dropout_in_16bits    = uint16_t(std::floor(p_dropout * 65535.0));
    float rp_dropout                = 1.0 / p_dropout;
    const unsigned long long seed   = 1;
    const unsigned long long offset = 0;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 13)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);

        input_permute  = std::stoi(argv[11]);
        output_permute = std::stoi(argv[12]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 11: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        printf("arg11 to 12: input / output permute\n");
        exit(0);
    }

    const ck::index_t BatchCount = G0 * G1;
    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> q_gs_ms_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]

    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> k_gs_ns_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]

    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> v_gs_os_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]

    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> y_gs_ms_os_strides =
        output_permute
            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]

    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> z_gs_ms_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]

    // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
    // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
    //    = exp(Si) / exp(log(sum(exp() + ...)))
    //    = exp(Si - log(sum(exp() + ...)))
    //               ^^^^^^^^^^^^^^^^^^^^^
    //                       LSE
    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
    Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
    Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
    Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
    Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
    Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
    Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

    std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
    std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
    std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
    std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
    std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
    std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;

    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
    switch(init_method)
    {
    case 0: break;
    case 1:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
        break;
    case 2:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
        break;
    case 3:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        break;
    case 4:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
        break;
    case 5:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
        // dO dot O = [0; 1; 2; ...]
        break;
    case 6:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
        // dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
        // dO dot O = [127.5; ...]
        // dS = P * (dP - dO dot O)
        //
        break;
    default:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
        // dP = dO V = ones
        // dS = P * (dP - (dO dot O))
        //    = 0.0039 * ones * (ones - 0.0039*256)
        //    = 0.0039 * ones * (ones - 1)
        //    = 0
    }
    // calculate y & log-sum-exp beforehand
    Tensor<DataType> q_g_m_k({BatchCount, M, K});
    Tensor<DataType> k_g_n_k({BatchCount, N, K});
    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
    Tensor<DataType> v_g_n_o({BatchCount, N, O});
    Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
    Tensor<DataType> p_g_m_n({BatchCount, M, N});
    Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
    Tensor<DataType> y_g_m_o({BatchCount, M, O});
    Tensor<LSEDataType> lse_g_m({BatchCount, M});

    q_gs_ms_ks.ForEach(
        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
    k_gs_ns_ks.ForEach(
        [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
    z_gs_ms_ns.ForEach(
        [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
    v_gs_os_ns.ForEach(
        [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
    lse_gs_ms.ForEach(
        [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

    run_attention_fwd_host(q_g_m_k,
                           k_g_n_k,
                           v_g_n_o,
                           alpha,
                           s_g_m_n,
                           p_g_m_n,
                           y_g_m_o,
                           lse_g_m,
                           p_drop_g_m_n,
                           z_g_m_n,
                           p_dropout_in_16bits,
                           rp_dropout);

    y_gs_ms_os.ForEach(
        [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
    lse_gs_ms.ForEach(
        [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
    // qkv gradients have the same descriptor as with qkv
    DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
    DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());

    q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
    k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
    v_device_buf.ToDevice(v_gs_os_ns.mData.data());
    y_device_buf.ToDevice(y_gs_ms_os.mData.data());
    lse_device_buf.ToDevice(lse_gs_ms.mData.data());
    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
    kgrad_device_buf.SetZero();
    vgrad_device_buf.SetZero();
    // z_device_buf.SetZero();
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
        static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
        static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
        static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
        {}, // std::array<void*, 1> p_acc0_biases;
        {}, // std::array<void*, 1> p_acc1_biases;
        q_gs_ms_ks_lengths,
        q_gs_ms_ks_strides,
        k_gs_ns_ks_lengths,
        k_gs_ns_ks_strides,
        z_gs_ms_ns_lengths,
        z_gs_ms_ns_strides,
        v_gs_os_ns_lengths,
        v_gs_os_ns_strides,
        y_gs_ms_os_lengths,
        y_gs_ms_os_strides,
        lse_gs_ms_lengths,
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        QKVElementOp{},
        QKVElementOp{},
        Scale{alpha},
        QKVElementOp{},
        YElementOp{},
        p_drop,
        std::tuple<unsigned long long, unsigned long long>(seed, offset));

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
    // 5 GEMM ops in total:
    // S_MNK / dP_MNO Gemm (Gemm0 rcr)
    // dQ_MKN Gemm (Gemm1 rrr)
    // dV_NOM / dK_NKM Gemm (Gemm2 crr)
    // 3x MNK + 2x MNO
    std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
    // Q/K/V/Y, dQ/dK/dV/dY, LSE
    std::size_t num_btype =
        (sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * N * O +
         sizeof(DataType) * M * O) *
            size_t(2) * BatchCount +
        sizeof(LSEDataType) * M * BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    // copy z matirx data form device
    std::ofstream file("./z_matrix_txt");
    z_device_buf.FromDevice(z_g_m_n.mData.data());
    file << z_g_m_n << std::endl;
    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

    bool pass = true;
    if(do_verification)
    {
        // run fowad again for y, cause z_g_m_n update
        run_attention_fwd_host(q_g_m_k,
                               k_g_n_k,
                               v_g_n_o,
                               alpha,
                               s_g_m_n,
                               p_g_m_n,
                               y_g_m_o,
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
                               p_dropout_in_16bits,
                               rp_dropout);
        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
        });
        y_device_buf.ToDevice(y_gs_ms_os.mData.data());
        //
        // call kernel again
        //
        // example set Z matrix to null, will not ouput z matrix data
        argument = gemm.MakeArgument(
            static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
            static_cast<ZDataType*>(nullptr), // set to nullptr
            static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
            static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
            static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
            k_gs_ns_ks_lengths,
            k_gs_ns_ks_strides,
            z_gs_ms_ns_lengths,
            z_gs_ms_ns_strides,
            v_gs_os_ns_lengths,
            v_gs_os_ns_strides,
            y_gs_ms_os_lengths,
            y_gs_ms_os_strides,
            lse_gs_ms_lengths,
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
            QKVElementOp{},
            QKVElementOp{},
            Scale{alpha},
            QKVElementOp{},
            YElementOp{},
            p_drop,
            std::tuple<unsigned long long, unsigned long long>(seed, offset));

        kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
        vgrad_device_buf.SetZero();
        invoker.Run(argument, StreamConfig{nullptr, false});
        Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
        Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
        Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
        Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
        Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
        Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
        Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});

        ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
            ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });

#if PRINT_HOST
        {
            std::cout << "q_g_m_k ref:\n" << q_g_m_k;
            std::cout << "k_g_n_k ref:\n" << k_g_n_k;
            std::cout << "v_g_n_o ref:\n" << v_g_n_o;
            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
        }
#endif
        // Gradients
        auto ref_gemm_grad         = ReferenceGemmGradInstance{};
        auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
        using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;

        // dP_dropout = dY * V^T
        auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
        ref_gemm_grad_invoker.Run(RefGemmGradArg{
            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
        {
            std::cout << "===== dP = dY * V^T\n";
            std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
            std::cout << "v_g_o_n ref:\n" << v_g_o_n;
            std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
        }
#endif
        // dP = dP_dropout x Z
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
            z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
        sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
            float ygrad_dot_y = 0;
            for(int o = 0; o < O; o++)
            {
                auto idx_gmo = idx_gmn;
                idx_gmo[2]   = o;
                ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
            }
            self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
        });
#if PRINT_HOST
        {
            std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
            std::cout << "p_g_m_n ref:\n" << p_g_m_n;
            std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
            std::cout << "y_g_m_o ref:\n" << y_g_m_o;
            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
            std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
        }
#endif
        // dV = P_drop^T * dY
        auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
        ref_gemm_grad_invoker.Run(RefGemmGradArg{
            p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
        {
            std::cout << "===== dV = P^T * dY\n";
            std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
            std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
        }
#endif
        // dQ = alpha * dS * K
        ref_gemm_grad_invoker.Run(RefGemmGradArg{
            sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
        {
            std::cout << "===== dQ = alpha * dS * K\n";
            std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
            std::cout << "k_g_n_k ref:\n" << k_g_n_k;
            std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
        }
#endif
        // dK = alpha * dS^T * Q
        auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
        ref_gemm_grad_invoker.Run(RefGemmGradArg{
            sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
        {
            std::cout << "===== dK = alpha * dS^T * Q\n";
            std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
            std::cout << "q_g_m_k ref:\n" << q_g_m_k;
            std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
        }
#endif
        Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
        Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
        Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);

        Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
        Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
        Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);

        qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
        kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
        vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());

        // permute
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });

        std::cout << "Checking qgrad:\n";
        pass &= ck::utils::check_err(
            qgrad_gs_ms_ks_device_result.mData, qgrad_gs_ms_ks_host_result.mData, "error", 1e-2, 1e-2);
        std::cout << "Checking kgrad:\n";
        pass &= ck::utils::check_err(
            kgrad_gs_ns_ks_device_result.mData, kgrad_gs_ns_ks_host_result.mData, "error", 1e-2, 1e-2);
        std::cout << "Checking vgrad:\n";
        pass &= ck::utils::check_err(
            vgrad_gs_os_ns_device_result.mData, vgrad_gs_os_ns_host_result.mData, "error", 1e-2, 1e-2);
    }

    return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
}

int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
100755 → 100644

@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -42,6 +43,7 @@ using B1DataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -69,6 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -78,6 +81,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -159,4 +163,5 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 #include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_b0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
+    std::vector<void*> p_z;
     std::vector<void*> p_lse;
     std::vector<std::vector<int>> g0_g1_m_n_k_o;
@@ -55,6 +56,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<B0DataType>> b0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;
+    std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;
 
     using DeviceMemPtr = std::unique_ptr<DeviceMem>;
@@ -62,6 +64,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> b0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
+    std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;
 
     std::size_t flop = 0, num_byte = 0;
@@ -102,6 +105,12 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+        std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> z_gs_ms_ns_strides =
+            output_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+
         std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
         std::vector<ck::index_t> lse_gs_ms_strides =
             std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
@@ -114,6 +123,8 @@ int run(int argc, char* argv[])
             b1_gs_os_ns_strides,
             c_gs_ms_os_lengths,
             c_gs_ms_os_strides,
+            z_gs_ms_ns_lengths,
+            z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
             {}, // acc0_biases_gs_ms_ns_lengths
@@ -126,6 +137,7 @@ int run(int argc, char* argv[])
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
 
         int Batch = G0 * G1;
@@ -140,10 +152,13 @@ int run(int argc, char* argv[])
                       << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                       << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                       << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
                       << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
                       << std::endl;
         }
 
+        z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
         switch(init_method)
         {
         case 0: break;
@@ -172,6 +187,7 @@ int run(int argc, char* argv[])
         b0_tensors.push_back(b0_gs_ns_ks);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
+        z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms_device_result);
 
         a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
@@ -182,6 +198,8 @@ int run(int argc, char* argv[])
             sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+        z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
         lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
@@ -193,6 +211,7 @@ int run(int argc, char* argv[])
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
         p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
         p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
         p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
     }
@@ -209,6 +228,7 @@ int run(int argc, char* argv[])
         p_b0,
         p_b1,
         p_c,
+        p_z,
         p_lse,
         {}, // p_acc0_biases
         {}, // p_acc1_biases
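The added Z strides above mirror the C strides: the random-number matrix is stored either as [G0, M, G1, N] (permuted output) or [G0, G1, M, N]. A small hedged sketch of the arithmetic behind initializers such as {M * G1 * N, N, G1 * N, 1}: packed row-major strides derived from a length vector (made-up sizes, not a repo helper).

// Sketch only: derives packed row-major strides from a length vector, the same
// arithmetic behind the stride initializers above. Sizes are hypothetical.
#include <iostream>
#include <vector>

std::vector<long> packed_strides(const std::vector<long>& lengths)
{
    std::vector<long> strides(lengths.size(), 1);
    for(int d = static_cast<int>(lengths.size()) - 2; d >= 0; --d)
        strides[d] = strides[d + 1] * lengths[d + 1];
    return strides;
}

int main()
{
    const long G0 = 2, G1 = 3, M = 4, N = 5;

    // Z layout [G0, G1, M, N] (no output permute)
    for(long s : packed_strides({G0, G1, M, N}))
        std::cout << s << ' '; // prints: 60 20 5 1  ==  {G1*M*N, M*N, N, 1}
    std::cout << '\n';

    // Z layout [G0, M, G1, N] (permuted); in the example these strides are then
    // listed in [G0, G1, M, N] index order as {M*G1*N, N, G1*N, 1}
    for(long s : packed_strides({G0, M, G1, N}))
        std::cout << s << ' '; // prints: 60 15 5 1
    std::cout << '\n';
    return 0;
}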
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
@@ -79,6 +79,7 @@ template <index_t NumDimG,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+        std::vector<index_t> z_gs_ms_ns_lengths;
+        std::vector<index_t> z_gs_ms_ns_strides;
         std::vector<index_t> lse_gs_ms_lengths;
         std::vector<index_t> lse_gs_ms_strides;
@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<const void*> p_b0_vec,
         std::vector<const void*> p_b1_vec,
         std::vector<void*> p_c_vec,
+        std::vector<void*> p_z_vec,
         std::vector<void*> p_lse_vec,
         std::vector<std::vector<const void*>> p_acc0_biases_vec,
         std::vector<std::vector<const void*>> p_acc1_biases_vec,
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
@@ -29,6 +29,7 @@ namespace device {
 template <typename GridwiseGemm,
           typename DataType,
+          typename ZDataType,
           typename LSEDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
+          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
           typename B1GridDesc_BK0_N_BK1,
           typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
           typename LSEGridDescriptor_M,
@@ -50,9 +52,10 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
             const DataType* __restrict__ p_a_grid,
             const DataType* __restrict__ p_b_grid,
+            ZDataType* __restrict__ p_z_grid,
             const DataType* __restrict__ p_b1_grid,
             const DataType* __restrict__ p_c_grid,
             const LSEDataType* __restrict__ p_lse_grid,
@@ -67,6 +70,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+            const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
             const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -76,7 +81,10 @@ __global__ void
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-            const C0MatrixMask c0_matrix_mask)
+            const C0MatrixMask c0_matrix_mask,
+            const float p_dropout,
+            const unsigned long long seed,
+            const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -90,6 +98,8 @@ __global__ void
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+    const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -97,8 +107,13 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);
+    ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
+
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
+                                                  z_matrix_ptr,
                                                   p_b1_grid + b1_batch_offset,
                                                   p_c_grid + c_batch_offset,
                                                   p_lse_grid + lse_batch_offset,
@@ -114,13 +129,16 @@ __global__ void
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
+                                                  c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   lse_grid_desc_m,
                                                   vgrad_grid_desc_n_o,
                                                   ygrad_grid_desc_m0_o_m1,
                                                   block_2_ctile_map,
-                                                  c0_matrix_mask);
+                                                  c0_matrix_mask,
+                                                  p_dropout,
+                                                  ph);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -151,6 +169,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
@@ -429,6 +448,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
     }
 
+    // Z in Gemm0 C position
+    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
+                                        const std::vector<index_t>& z_gs_ms_ns_strides_vec)
+    {
+        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
+    }
+
     //
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
     //
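The dS identity in the comment above is the softmax-backward step this operator relies on. Written out (notation is a reading of the surrounding code, not stated in the commit: S the pre-softmax logits, P = softmax(S), Y = P V the first-gemm output, dP = dY V^T), a worked form is

    dS_{ij} = P_{ij} (dP_{ij} - \sum_k dP_{ik} P_{ik}) = P_{ij} (dP_{ij} - dY_i \cdot Y_i)

The two reduction forms agree because Y_i = \sum_k P_{ik} V_k, so dY_i \cdot Y_i = \sum_k P_{ik} (dY_i \cdot V_k) = \sum_k P_{ik} dP_{ik}, which is why the kernel can reuse the already computed row dot product dY_i . Y_i instead of reducing over dP.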
@@ -489,9 +514,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     using BGridDesc_G_N_K   = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K  = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N   = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N   = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
     using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
     using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
+    using ZGridDesc_M_N         = decltype(MakeZGridDescriptor_M_N({}, {}));
 
     constexpr static auto make_MaskOutPredicate()
     {
@@ -510,11 +537,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     {
         ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                      const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                      const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                      const CGridDesc_G_M_N& c_grid_desc_g_m_n,
                                      index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
               BatchStrideLSE_(BatchStrideLSE)
@@ -531,6 +560,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
             return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }
 
+        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
+        {
+            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+        }
+
         __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
         {
             return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
@@ -549,13 +583,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         private:
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
     };
 
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         DataType, // TODO: distinguish A/B datatype
         LSEDataType,
         GemmAccDataType,
@@ -568,6 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
+        ZGridDesc_M_N,
         B1GridDesc_BK0_N_BK1,
         YGridDesc_M_O,
         LSEGridDesc_M,
@@ -624,6 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         Argument(const DataType* p_a_grid,
                  const DataType* p_b_grid,
+                 ZDataType* p_z_grid,
                  const DataType* p_b1_grid,
                  const DataType* p_c_grid, // for dS
                  const LSEDataType* p_lse_grid,
@@ -637,6 +675,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                  const std::vector<index_t>& a_gs_ms_ks_strides,
                  const std::vector<index_t>& b_gs_ns_ks_lengths,
                  const std::vector<index_t>& b_gs_ns_ks_strides,
+                 const std::vector<index_t>& z_gs_ms_ns_lengths,
+                 const std::vector<index_t>& z_gs_ms_ns_strides,
                  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
@@ -652,9 +692,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                  BElementwiseOperation b_element_op,
                  AccElementwiseOperation acc_element_op,
                  B1ElementwiseOperation b1_element_op,
-                 CElementwiseOperation c_element_op)
+                 CElementwiseOperation c_element_op,
+                 float p_drop,
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
+              p_z_grid_{p_z_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              p_lse_grid_{p_lse_grid},
@@ -666,6 +709,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
+             z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
@@ -683,6 +727,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
+             z_grid_desc_g_m_n_{
+                 Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              y_grid_desc_mblock_mperblock_oblock_operblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
              a_element_op_{a_element_op},
@@ -707,6 +753,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
              compute_base_ptr_of_batch_{
                  a_grid_desc_g_m_k_,
                  b_grid_desc_g_n_k_,
+                 z_grid_desc_g_m_n_,
                  b1_grid_desc_g_n_k_,
                  c_grid_desc_g_m_n_,
                  type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
@@ -729,6 +776,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         y_grid_desc_m_o_);
             }
+
+            p_dropout_        = 1.f - p_drop;
+            float rp_dropout_ = 1.f / p_dropout_;
+            acc_element_op_.Append(rp_dropout_);
+
+            seed_   = std::get<0>(seeds);
+            offset_ = std::get<1>(seeds);
+
+            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
 
             // Print();
         }
@@ -760,6 +817,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // pointers
         const DataType* p_a_grid_;
         const DataType* p_b_grid_;
+        ZDataType* p_z_grid_;
         const DataType* p_b1_grid_;
         const DataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
@@ -771,6 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // tensor descriptor
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        ZGridDesc_M_N z_grid_desc_m_n_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         YGridDesc_M_O y_grid_desc_m_o_;
         LSEGridDesc_M lse_grid_desc_m_;
@@ -782,9 +841,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
+        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
             y_grid_desc_mblock_mperblock_oblock_operblock_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
 
         // block-to-c-tile map
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -807,6 +870,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+        float p_dropout_;
+        unsigned long long seed_;
+        unsigned long long offset_;
     };
 
     // Invoker
@@ -831,9 +898,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         float ave_time = 0;
 
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
+            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
                 GridwiseGemm,
                 DataType,
+                ZDataType,
                 LSEDataType,
                 AElementwiseOperation,
                 BElementwiseOperation,
@@ -842,6 +910,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 CElementwiseOperation,
                 DeviceOp::AGridDesc_AK0_M_AK1,
                 DeviceOp::BGridDesc_BK0_N_BK1,
+                typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                 DeviceOp::B1GridDesc_BK0_N_BK1,
                 typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                 DeviceOp::LSEGridDesc_M,
@@ -859,6 +928,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 0,
                 arg.p_a_grid_,
                 arg.p_b_grid_,
+                arg.p_z_grid_,
                 arg.p_b1_grid_,
                 arg.p_c_grid_,
                 arg.p_lse_grid_,
@@ -873,6 +943,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 arg.c_element_op_,
                 arg.a_grid_desc_ak0_m_ak1_,
                 arg.b_grid_desc_bk0_n_bk1_,
+                arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
                 arg.lse_grid_desc_m_,
@@ -881,7 +952,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
                 arg.compute_base_ptr_of_batch_,
-                arg.c0_matrix_mask_);
+                arg.c0_matrix_mask_,
+                arg.p_dropout_,
+                arg.seed_,
+                arg.offset_);
         };
 
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -992,6 +1066,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     static auto MakeArgument(
         const DataType* p_a,
         const DataType* p_b,
+        ZDataType* p_z,
         const DataType* p_b1,
         const DataType* p_c,
         const LSEDataType* p_lse,
@@ -1005,6 +1080,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         const std::vector<index_t>& a_gs_ms_ks_strides,
         const std::vector<index_t>& b_gs_ns_ks_lengths,
         const std::vector<index_t>& b_gs_ns_ks_strides,
+        const std::vector<index_t>& z_gs_ms_ns_lengths,
+        const std::vector<index_t>& z_gs_ms_ns_strides,
         const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
         const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
         const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
@@ -1020,10 +1097,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         BElementwiseOperation b_element_op,
         AccElementwiseOperation acc_element_op,
         B1ElementwiseOperation b1_element_op,
-        CElementwiseOperation c_element_op)
+        CElementwiseOperation c_element_op,
+        float p_drop,
+        std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a,
                         p_b,
+                        p_z,
                         p_b1,
                         p_c,
                         p_lse,
@@ -1037,6 +1117,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                         a_gs_ms_ks_strides,
                         b_gs_ns_ks_lengths,
                         b_gs_ns_ks_strides,
+                        z_gs_ms_ns_lengths,
+                        z_gs_ms_ns_strides,
                         b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                        b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
@@ -1050,7 +1132,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                        b_element_op,
                        acc_element_op,
                        b1_element_op,
-                       c_element_op};
+                       c_element_op,
+                       p_drop,
+                       seeds};
    }
 
    static auto MakeInvoker() { return Invoker{}; }
@@ -1060,6 +1144,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b,
+       void* p_z,
        const void* p_b1,
        const void* p_c,
        const void* p_lse,
@@ -1073,6 +1158,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
+       const std::vector<index_t>& z_gs_ms_ns_lengths,
+       const std::vector<index_t>& z_gs_ms_ns_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
@@ -1088,10 +1175,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
-       CElementwiseOperation c_element_op) // override
+       CElementwiseOperation c_element_op,
+       float p_drop,
+       std::tuple<unsigned long long, unsigned long long> seeds) // override
    {
        return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
                                          static_cast<const DataType*>(p_b),
+                                         static_cast<ZDataType*>(p_z),
                                          static_cast<const DataType*>(p_b1),
                                          static_cast<const DataType*>(p_c),
                                          static_cast<const LSEDataType*>(p_lse),
@@ -1105,6 +1195,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
+                                         z_gs_ms_ns_lengths,
+                                         z_gs_ms_ns_strides,
                                          b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                                          b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                                          c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
@@ -1118,7 +1210,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                          b_element_op,
                                          acc_element_op,
                                          b1_element_op,
-                                         c_element_op);
+                                         c_element_op,
+                                         p_drop,
+                                         seeds);
    }
 
    // polymorphic
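The constructor additions above derive a keep probability p_dropout_ = 1 - p_drop and fold the rescale factor 1 / p_dropout_ into the accumulator elementwise op, i.e. inverted dropout: survivors are scaled up so the expected value of the attention weights is unchanged. A small hedged sketch of that convention, with std::mt19937 standing in for the repo's Philox generator and a made-up accumulator array:

// Sketch only: inverted dropout -- keep with probability (1 - p_drop) and
// rescale survivors by 1 / (1 - p_drop). std::mt19937 is a stand-in for the
// repo's ck::philox; values and sizes are hypothetical.
#include <iostream>
#include <random>
#include <vector>

int main()
{
    const float p_drop     = 0.2f;
    const float p_dropout  = 1.f - p_drop;     // keep probability
    const float rp_dropout = 1.f / p_dropout;  // rescale factor

    std::mt19937 gen(42);
    std::uniform_real_distribution<float> uni(0.f, 1.f);

    std::vector<float> acc(8, 1.0f); // pretend attention accumulators
    for(float& a : acc)
        a = (uni(gen) < p_dropout) ? a * rp_dropout : 0.f;

    float sum = 0.f;
    for(float a : acc)
        sum += a;
    std::cout << "mean after dropout ~ " << sum / acc.size() << '\n';
    return 0;
}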
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle
@@ -10,153 +10,130 @@
 #include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
 template <typename GridwiseGemm,
-          typename DataType,
-          typename ZDataType,
-          typename LSEDataType,
+          typename GemmAccDataType,
+          typename GroupKernelArg,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
-          typename B1GridDesc_BK0_N_BK1,
-          typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
-          typename LSEGridDescriptor_M,
-          typename VGradGridDescriptor_N_O,
-          typename YGradGridDesc_M0_O_M1,
-          typename Block2CTileMap,
-          typename ComputeBasePtrOfStridedBatch,
-          typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool IsDropout>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
-            const DataType* __restrict__ p_a_grid,
-            const DataType* __restrict__ p_b_grid,
-            ZDataType* __restrict__ p_z_grid,
-            const DataType* __restrict__ p_b1_grid,
-            const DataType* __restrict__ p_c_grid,
-            const LSEDataType* __restrict__ p_lse_grid,
-            const DataType* __restrict__ p_ygrad_grid,
-            DataType* __restrict__ p_qgrad_grid,
-            DataType* __restrict__ p_kgrad_grid,
-            DataType* __restrict__ p_vgrad_grid,
+        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+            const index_t group_count,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const AccElementwiseOperation acc_element_op,
             const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-            const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const LSEGridDescriptor_M lse_grid_desc_m,
-            const VGradGridDescriptor_N_O vgrad_grid_desc_n_o,
-            const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
-            const Block2CTileMap block_2_ctile_map,
-            const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-            const C0MatrixMask c0_matrix_mask,
-            const float p_dropout,
+            const ushort p_dropout_in_16bits,
+            const GemmAccDataType p_dropout_rescale,
             const unsigned long long seed,
             const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    // NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
-    // offsets
+    const index_t block_id         = get_block_1d_id();
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);
+
+    const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
+        cast_pointer_to_generic_address_space(group_kernel_args));
+
+    index_t left     = 0;
+    index_t right    = group_count;
+    index_t group_id = index_t((left + right) / 2);
+    while(
+        (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
+    {
+        if(block_id < arg_ptr[group_id].block_start_)
+        {
+            right = group_id;
+        }
+        else
+        {
+            left = group_id;
+        }
+        group_id = index_t((left + right) / 2);
+    }
+
+    // per-group batch offset
+    const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
+    const index_t g_idx                = __builtin_amdgcn_readfirstlane(
+        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
 
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
+        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
-    const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
-    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
+    const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
+    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
 
-    const index_t global_thread_id = get_thread_global_1d_id();
-    ck::philox ph(seed, global_thread_id, offset);
-    ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
+    // unsigned short* p_z_grid_in = //
+    //     (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+    //                                             : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
-                                                  z_matrix_ptr,
-                                                  p_b1_grid + b1_batch_offset,
-                                                  p_c_grid + c_batch_offset,
-                                                  p_lse_grid + lse_batch_offset,
-                                                  p_ygrad_grid + c_batch_offset,
-                                                  p_qgrad_grid + a_batch_offset,
-                                                  p_kgrad_grid + b_batch_offset,
-                                                  p_vgrad_grid + b1_batch_offset,
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
         a_element_op,
         b_element_op,
         acc_element_op,
         b1_element_op,
         c_element_op,
-        a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-        b1_grid_desc_bk0_n_bk1,
-        c_grid_desc_mblock_mperblock_nblock_nperblock,
-        lse_grid_desc_m,
-        vgrad_grid_desc_n_o,
-        ygrad_grid_desc_m0_o_m1,
-        block_2_ctile_map,
-        c0_matrix_mask,
-        p_dropout,
+        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+        arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, ////////
+        arg_ptr[group_id].lse_grid_desc_m_,
+        arg_ptr[group_id].block_2_ctile_map_,
+        arg_ptr[group_id].c0_matrix_mask_,
+        p_dropout_in_16bits,
+        p_dropout_rescale,
        ph);
 #else
-    ignore = p_a_grid;
-    ignore = p_b_grid;
-    ignore = p_b1_grid;
-    ignore = p_c_grid;
+    ignore = group_kernel_args;
+    ignore = group_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
-    ignore = a_grid_desc_ak0_m_ak1;
-    ignore = b_grid_desc_bk0_n_bk1;
-    ignore = b1_grid_desc_bk0_n_bk1;
-    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = block_2_ctile_map;
-    ignore = batch_count;
-    ignore = compute_base_ptr_of_batch;
-    ignore = c0_matrix_mask;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
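The grouped kernel above maps a block id to its problem group with a binary search over per-group [block_start_, block_end_) ranges held in the kernel-argument array. A standalone sketch of the same lookup, host-side and with made-up ranges rather than the device-side GroupKernelArg array:

// Sketch only: block-id -> group-id lookup via binary search over contiguous,
// sorted [block_start, block_end) ranges, mirroring the loop in the kernel
// above. Assumes block_id falls inside some range, as the kernel does; the
// ranges below are hypothetical.
#include <cassert>
#include <iostream>
#include <vector>

struct GroupRange
{
    int block_start;
    int block_end; // exclusive
};

int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start && block_id < groups[group_id].block_end))
    {
        if(block_id < groups[group_id].block_start)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // three problems occupying blocks [0,8), [8,20), [20,26)
    const std::vector<GroupRange> groups{{0, 8}, {8, 20}, {20, 26}};
    assert(find_group(groups, 3) == 0);
    assert(find_group(groups, 8) == 1);
    assert(find_group(groups, 25) == 2);
    std::cout << "block 13 belongs to group " << find_group(groups, 13) << '\n';
    return 0;
}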
...
@@ -168,7 +145,10 @@ template <index_t NumDimG,
...
@@ -168,7 +145,10 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimN,
index_t NumDimK,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
index_t NumDimO, // NumDimGemm1N
typename
DataType
,
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename ZDataType,
typename ZDataType,
typename LSEDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc0BiasDataType,
...
@@ -227,8 +207,26 @@ template <index_t NumDimG,
...
@@ -227,8 +207,26 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
LoopScheduler LoopSched = LoopScheduler::Default>
struct
DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
: public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
"Number of dimension must be greater than 0");
...
@@ -236,11 +234,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -236,11 +234,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO: implement bias combination
// TODO
ANT
: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
#if 0
// TODO: use alias
// TODO
ANT
: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm0K = NumDimK;
...
@@ -249,28 +247,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -249,28 +247,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
;
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t Q_K1 = 8;
    static constexpr index_t K_K1 = 8;
    static constexpr index_t V_N1 = 2;
    static constexpr index_t Q_M1 = 2;
    static constexpr index_t K_N1 = 2;
    static constexpr index_t V_O1 = 8;
    static constexpr index_t Y_O1 = 8;
    static constexpr index_t Y_M1 = 2;

    static constexpr auto padder = GemmGemmPadder<GemmSpec,
                                                  Number<MPerBlock>,
                                                  Number<NPerBlock>,
                                                  Number<KPerBlock>,
                                                  Number<Gemm1NPerBlock>>{};
    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
...
@@ -280,18 +281,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -280,18 +281,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
B1Spec,
B1Spec,
CSpec>;
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
{
...
@@ -300,7 +289,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -300,7 +289,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
Number<AK1>{});
Number<AK1>{});
}
}
// K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
{
...
@@ -309,7 +297,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -309,7 +297,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
Number<BK1>{});
Number<BK1>{});
}
}
// V in Gemm B1 position
static auto
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
...
@@ -320,165 +307,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -320,165 +307,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
Number<B1K1>{});
Number<B1K1>{});
}
}
    //
    // dV = P^T * dY
    //
    // VGrad in Gemm C position
    static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
                                            const std::vector<index_t>& v_gs_os_ns_strides_vec)
    {
        // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
        // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
        // transformation overhead
        // TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
        // extract subsequence and shuffle them.
        const index_t num_dims = NumDimG + NumDimN + NumDimO;

        // 0, 1, .. NumDimG - 1
        std::vector<index_t> gs_ids(NumDimG);
        std::iota(gs_ids.begin(), gs_ids.end(), 0);

        // NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
        std::vector<index_t> os_ids(NumDimO);
        std::iota(os_ids.begin(), os_ids.end(), NumDimG);

        // NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
        std::vector<index_t> ns_ids(NumDimN);
        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);

        std::vector<index_t> ids_old2new;
        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());

        std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
        for(int i = 0; i < num_dims; i++)
        {
            index_t id_new = ids_old2new[i];
            v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
            v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
        }

        const auto vgrad_desc_nraw_oraw =
            MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
                v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
                .second;

        return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                   make_tuple(NPerBlock, Gemm1NPerBlock),
                                   Sequence<padder.PadN, padder.PadO>{});
    }
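    // Illustrative aside (not part of the diff): for a hypothetical layout with NumDimG = 2,
    // NumDimO = 1, NumDimN = 1, the reorder above is just the permutation {0, 1, 3, 2}, i.e.
    //   v_gs_os_ns_lengths = {G0, G1, O, N}     ->  v_gs_ns_os_lengths = {G0, G1, N, O}
    //   v_gs_os_ns_strides = {sG0, sG1, sO, sN} ->  v_gs_ns_os_strides = {sG0, sG1, sN, sO}
    // so the dV descriptor is built directly in the row-major N x O order that the output
    // gemm writes.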
    template <typename YGridDesc_M_O>
    static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
    {
        const auto M = ygrad_grid_desc_m_o.GetLength(I0);
        const auto O = ygrad_grid_desc_m_o.GetLength(I1);

        const auto Y_M0 = M / Y_M1;

        return transform_tensor_descriptor(
            ygrad_grid_desc_m_o,
            make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
                       make_pass_through_transform(O)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
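    // Illustrative aside (not part of the diff): with Y_M1 = 2 the unmerge above views dY's M
    // rows as an (M0, M1) = (M / 2, 2) pair, i.e. row m maps to (m0, m1) = (m / 2, m % 2),
    // producing the M0 x O x M1 ordering that the dP gemm consumes as its A operand.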
    //
    // dP = dY * V^T
    //
    // YGrad in Gemm A position
    static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec,
                                                const std::vector<index_t>& y_gs_ms_os_strides_vec)
    {
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec),
            Number<Y_O1>{});
    }
    // V in Gemm B position
    static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
                                            const std::vector<index_t>& v_gs_os_ns_strides_vec)
    {
        // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
        // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
        // transformation overhead
        // TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
        // extract subsequence and shuffle them.
        const index_t num_dims = NumDimG + NumDimN + NumDimO;

        // 0, 1, .. NumDimG - 1
        std::vector<index_t> gs_ids(NumDimG);
        std::iota(gs_ids.begin(), gs_ids.end(), 0);

        // NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
        std::vector<index_t> os_ids(NumDimO);
        std::iota(os_ids.begin(), os_ids.end(), NumDimG);

        // NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
        std::vector<index_t> ns_ids(NumDimN);
        std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);

        std::vector<index_t> ids_old2new;
        ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
        ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
        ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());

        std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
        for(int i = 0; i < num_dims; i++)
        {
            index_t id_new = ids_old2new[i];
            v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
            v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
        }

        const auto v_grid_desc_nraw_oraw =
            MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
                v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
                .second;

        const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
                                                         make_tuple(NPerBlock, Gemm1NPerBlock),
                                                         Sequence<padder.PadN, padder.PadO>{});

        // N_O to O0_N_O1; to refactor
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
    }
    // Z in Gemm0 C position
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& z_gs_ms_ns_strides_vec)
    {
        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
    }

    //
    // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
    //

    //
    // dQ = alpha * dS * K
    //
    // QGrad in Gemm C position
    static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec,
                                            const std::vector<index_t>& q_gs_ms_ks_strides_vec)
    {
        return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec);
    }

    //
    // dK = alpha * dS^T * Q
    //
    // KGrad in Gemm C position
    static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec,
                                            const std::vector<index_t>& k_gs_ns_ks_strides_vec)
    {
        return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
    }
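For reference, the five gemms named by the descriptor comments above compose as follows. This is an illustrative single-batch, single-head CPU sketch written for this review (not part of the commit); it ignores dropout, masking and the log-sum-exp trick and simply follows the formulas in the comments: dV = P^T dY, dP = dY V^T, dS = P .* (dP - dY.Y), dQ = alpha dS K, dK = alpha dS^T Q.

#include <algorithm>
#include <cmath>
#include <vector>

// Q: M x Kd, K: N x Kd, V: N x O, Y/dY: M x O; outputs dQ: M x Kd, dK: N x Kd, dV: N x O.
// Output vectors are assumed to be correctly sized and zero-initialized by the caller.
void attention_backward_reference(int M, int N, int Kd, int O, float alpha,
                                  const std::vector<float>& Q, const std::vector<float>& Kmat,
                                  const std::vector<float>& V, const std::vector<float>& Y,
                                  const std::vector<float>& dY, std::vector<float>& dQ,
                                  std::vector<float>& dK, std::vector<float>& dV)
{
    std::vector<float> P(M * N), dS(M * N);

    // Recompute P = softmax(alpha * Q * K^T), row by row.
    for(int i = 0; i < M; ++i)
    {
        float row_max = -INFINITY, row_sum = 0.f;
        for(int j = 0; j < N; ++j)
        {
            float s = 0.f;
            for(int k = 0; k < Kd; ++k)
                s += Q[i * Kd + k] * Kmat[j * Kd + k];
            P[i * N + j] = alpha * s;
            row_max = std::max(row_max, P[i * N + j]);
        }
        for(int j = 0; j < N; ++j)
            row_sum += (P[i * N + j] = std::exp(P[i * N + j] - row_max));
        for(int j = 0; j < N; ++j)
            P[i * N + j] /= row_sum;
    }

    // dV = P^T * dY
    for(int j = 0; j < N; ++j)
        for(int o = 0; o < O; ++o)
            for(int i = 0; i < M; ++i)
                dV[j * O + o] += P[i * N + j] * dY[i * O + o];

    // dS = P .* (dP - dY.Y), with dP = dY * V^T and dY.Y the per-row dot product.
    for(int i = 0; i < M; ++i)
    {
        float dy_dot_y = 0.f;
        for(int o = 0; o < O; ++o)
            dy_dot_y += dY[i * O + o] * Y[i * O + o];
        for(int j = 0; j < N; ++j)
        {
            float dp = 0.f;
            for(int o = 0; o < O; ++o)
                dp += dY[i * O + o] * V[j * O + o];
            dS[i * N + j] = P[i * N + j] * (dp - dy_dot_y);
        }
    }

    // dQ = alpha * dS * K ; dK = alpha * dS^T * Q
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
            for(int k = 0; k < Kd; ++k)
            {
                dQ[i * Kd + k] += alpha * dS[i * N + j] * Kmat[j * Kd + k];
                dK[j * Kd + k] += alpha * dS[i * N + j] * Q[i * Kd + k];
            }
}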
    static auto MakeLSEGridDescriptor_M(index_t MRaw)
    {
...
@@ -508,18 +341,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
    using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
    using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
    constexpr static auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
...
@@ -537,15 +368,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
    {
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n,
                                     const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                     index_t BatchStrideLSE)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
              z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
              BatchStrideLSE_(BatchStrideLSE)
        {
        }
...
@@ -560,11 +391,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
        {
            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...
@@ -575,6 +401,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
        {
            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }
        __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
...
@@ -583,19 +414,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
        index_t BatchStrideLSE_;
    };
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
        DataType, // TODO: distinguish A/B datatype
        ADataType, // TODO: distinguish A/B datatype
        LSEDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        LSEDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
...
@@ -604,9 +435,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        ZGridDesc_M_N,
        B1GridDesc_BK0_N_BK1,
        YGridDesc_M_O,
        CGridDesc_M_N,
        ZGridDesc_M_N,
        LSEGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
...
@@ -655,225 +486,242 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
Transform::matrix_padder.PadN,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
struct GroupKernelArg
{
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
ZDataType* p_z_grid_;
LSEDataType* p_lse_grid_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
// batch & stride
index_t num_blocks_per_batch_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// block-to-c-tile map
Block2CTileMap block_2_ctile_map_;
index_t block_start_, block_end_;
};
struct GroupDeviceArg
{
// lengths for the last dimensions of overall problem for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
// strides for the last dimensions of each tensor for sanity check of vector load/store
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
};
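GroupKernelArg is copied bitwise into the device workspace at launch time, so it deliberately holds only pointers, tensor descriptors and plain indices, while GroupDeviceArg keeps the std::vector-based data that IsSupportedArgument() only needs on the host. A hedged sanity check one could add under that assumption (that the descriptors and tile map are trivially copyable is an assumption of this note, not something guaranteed by the diff; requires <type_traits>):

    static_assert(std::is_trivially_copyable<GroupKernelArg>::value,
                  "GroupKernelArg is memcpy'd into the workspace, so it must stay trivially copyable");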
// Argument
// Argument
// FIXME: constness
struct Argument : public BaseArgument
struct Argument : public BaseArgument
{
{
Argument
(
Argument(std::vector<const void*> p_a_vec,
const
DataType
*
p_a_grid
,
std::vector<const void*> p_b_vec,
const
DataType
*
p_b_grid
,
std::vector<const void*> p_b1_vec,
ZDataType
*
p_z_grid
,
std::vector<void*> p_c_vec,
const
DataType
*
p_b1_grid
,
std::vector<void*> p_z_vec,
const
DataType
*
p_c_grid
,
// for dS
std::vector<void*> p_lse_vec,
const
LSEDataType
*
p_lse_grid
,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
const
DataType
*
p_ygrad_grid
,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
DataType
*
p_qgrad_grid
,
std::vector<ProblemDesc> problem_desc_vec,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op,
float
p_drop
,
float p_drop
out,
std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
:
p_a_grid_
{
p_a_grid
},
: a_element_op_{a_element_op},
p_b_grid_
{
p_b_grid
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_lse_grid_
{
p_lse_grid
},
p_ygrad_grid_
{
p_ygrad_grid
},
p_qgrad_grid_
{
p_qgrad_grid
},
p_kgrad_grid_
{
p_kgrad_grid
},
p_vgrad_grid_
{
p_vgrad_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeB1GridDescriptor_BK0_N_BK1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
vgrad_grid_desc_n_o_
{
DeviceOp
::
MakeVGradGridDescriptor_N_O
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
ygrad_grid_desc_m0_o_m1_
{
DeviceOp
::
MakeYGradGridDescriptor_M0_O_M1
(
y_grid_desc_m_o_
)},
// batch offsets
a_grid_desc_g_m_k_
{
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
b1_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
a_element_op_
{
a_element_op
},
b_element_op_{b_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
b1_element_op_{b1_element_op},
c_element_op_
{
c_element_op
},
c_element_op_{c_element_op}
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
b1_gs_gemm1ns_gemm1ks_lengths
[
NumDimG
+
NumDimO
-
1
]},
a_mz_kz_strides_
{
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
],
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
]},
b_nz_kz_strides_
{
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
+
NumDimK
-
1
]},
b1_nz_kz_strides_
{
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
-
1
],
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
{
{
// TODO: implement bias addition
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
group_count_ = problem_desc_vec.size();
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
ignore
=
acc0_biases_gs_ms_ns_strides
;
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
y_grid_desc_m_o_
,
block_2_ctile_map_
))
{
{
y_grid_desc_mblock_mperblock_oblock_operblock_
=
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o_
);
}
}
p_dropout_
=
1.
f
-
p_drop
;
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
float
rp_dropout_
=
1.
f
/
p_dropout_
;
{
acc_element_op_
.
Append
(
rp_dropout_
);
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
seed_
=
std
::
get
<
0
>
(
seeds
);
grid_size_ = 0;
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
for(std::size_t i = 0; i < group_count_; i++)
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
{
// Print();
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto lse_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
//typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k,
b_grid_desc_g_n_k,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
}
void
Print
()
const
group_kernel_args_.push_back({p_a_grid,
{
p_b_grid,
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
p_b1_grid,
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
p_c_grid,
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
p_z_grid,
// a_grid_desc_g_m_k_.Print();
p_lse_grid,
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
a_grid_desc_ak0_m_ak1,
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
b_grid_desc_bk0_n_bk1,
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
b1_grid_desc_bk0_n_bk1,
// b_grid_desc_g_n_k_.Print();
c_grid_desc_mblock_mperblock_nblock_nperblock,
std
::
cout
<<
"b1_grid_desc_g_o_n_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
z_grid_desc_m_n,
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
lse_grid_desc_m,
// b1_grid_desc_g_n_k_.Print();
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
std
::
cout
<<
"c_grid_desc_g_m_o_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
compute_base_ptr_of_batch,
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
c0_matrix_mask,
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
block_2_ctile_map,
// c_grid_desc_g_m_n_.Print();
BlockStart,
std
::
cout
<<
"vgrad_grid_desc_n_o_: "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I0
)
<<
", "
BlockEnd});
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I1
)
<<
'\n'
;
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
group_device_args_.push_back(
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
{problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
{problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
{problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
}
}
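    // Editor's note (illustrative, not part of the diff): at this point the constructor has
    // assigned every group a contiguous [BlockStart, BlockEnd) slice of the fused launch grid,
    // with grid_size_ = sum over groups of CalculateGridSize(c_grid_desc_m_n) * batch_count.
    // A workgroup later recovers its group by testing which slice contains its block id.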
            is_dropout_ = p_dropout > 0.0; //
            p_dropout_ = 1.f - p_dropout;
            p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
            p_dropout_ = 1.f / p_dropout_;
            p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);

            seed_   = std::get<0>(seeds);
            offset_ = std::get<1>(seeds);
        }

        // pointers
        const DataType* p_a_grid_;
        const DataType* p_b_grid_;
        ZDataType* p_z_grid_;
        const DataType* p_b1_grid_;
        const DataType* p_c_grid_;
        const LSEDataType* p_lse_grid_;
        const DataType* p_ygrad_grid_;
        DataType* p_qgrad_grid_;
        DataType* p_kgrad_grid_;
        DataType* p_vgrad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        ZGridDesc_M_N z_grid_desc_m_n_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        YGridDesc_M_O y_grid_desc_m_o_;
        LSEGridDesc_M lse_grid_desc_m_;
        VGradGridDesc_N_O vgrad_grid_desc_n_o_;
        YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;

        // batch offsets
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;

        typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
            y_grid_desc_mblock_mperblock_oblock_operblock_;
        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

        std::vector<GroupKernelArg> group_kernel_args_;
        std::vector<GroupDeviceArg> group_device_args_;

        std::size_t group_count_;
        index_t grid_size_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // For robust IsSupportedArgument() check
        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
        std::vector<index_t> a_mz_kz_strides_;
        std::vector<index_t> b_nz_kz_strides_;
        std::vector<index_t> b1_nz_kz_strides_;
        std::vector<index_t> c_mz_gemm1nz_strides_;

        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        float p_dropout_;
        ushort p_dropout_in_16bits_;
        unsigned long long seed_;
        unsigned long long offset_;
        GemmAccDataType p_dropout_rescale_;
        bool is_dropout_;
    };
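The constructor derives all dropout constants on the host: the keep probability 1 - p_dropout, a 16-bit threshold used to turn Philox samples into a keep/drop mask, and the 1/keep rescale applied to surviving elements. A standalone sketch of that arithmetic (illustrative only; exactly how the kernel compares its random sample against the threshold is assumed here, not quoted from the commit):

#include <cmath>
#include <cstdint>

struct DropoutParams
{
    float keep_prob;         // 1 - p_drop
    uint16_t keep_threshold; // keep an element when its 16-bit Philox sample falls below this
    float rescale;           // 1 / keep_prob, applied to the kept elements of P
};

DropoutParams make_dropout_params(float p_drop)
{
    const float keep = 1.f - p_drop;
    return {keep, static_cast<uint16_t>(std::floor(keep * 65535.0)), 1.f / keep};
}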
// Invoker
// Invoker
...
@@ -888,88 +736,90 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -888,88 +736,90 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
throw std::runtime_error("wrong! unsupported argument");
throw std::runtime_error("wrong! unsupported argument");
}
}
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;

            // Gemm0_K
            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            bool all_has_main_k_block_loop  = true;
            bool some_has_main_k_block_loop = false;
            for(std::size_t i = 0; i < arg.group_count_; i++)
            {
                const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                               arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
                const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
                all_has_main_k_block_loop &= y;
                some_has_main_k_block_loop |= y;
            }

            hipGetErrorString(hipMemcpy(arg.p_workspace_,
                                        arg.group_kernel_args_.data(),
                                        arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
                                        hipMemcpyHostToDevice));

            float ave_time = 0;
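The grouped path concatenates every group's tile grid into one launch: the constructor accumulated per-group [block_start_, block_end_) ranges into grid_size_, and the Invoker uploads those GroupKernelArg records to the workspace so each workgroup can locate its group from its block id. A hedged host-side sketch of that bookkeeping (names are illustrative, not from this header):

#include <cstddef>
#include <vector>

struct GroupRange { int block_start; int block_end; };

// tiles_per_group[g] plays the role of block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n).
int partition_grid(const std::vector<int>& tiles_per_group,
                   const std::vector<int>& batches_per_group,
                   std::vector<GroupRange>& ranges)
{
    int grid = 0;
    for(std::size_t g = 0; g < tiles_per_group.size(); ++g)
    {
        const int blocks = tiles_per_group[g] * batches_per_group[g];
        ranges.push_back({grid, grid + blocks});
        grid += blocks;
    }
    return grid; // total grid size of the single fused kernel launch
}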
            auto launch_kernel = [&](auto has_main_k_block_loop_) {
            auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
                const auto kernel =
                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
                    GridwiseGemm,
                    DataType,
                    GemmAccDataType,
                    ZDataType,
                    GroupKernelArg,
                    LSEDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    AccElementwiseOperation,
                    B1ElementwiseOperation,
                    CElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    has_main_k_block_loop_,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    is_dropout_>;
                    typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                    DeviceOp::B1GridDesc_BK0_N_BK1,
                    typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                    DeviceOp::LSEGridDesc_M,
                    DeviceOp::VGradGridDesc_N_O,
                    DeviceOp::YGradGridDesc_M0_O_M1,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    ComputeBasePtrOfStridedBatch,
                    C0MatrixMask,
                    has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(grid_size),
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    arg.p_a_grid_,
                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                    arg.p_b_grid_,
                    arg.group_count_,
                    arg.p_z_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
                    arg.p_lse_grid_,
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.acc_element_op_,
                    arg.b1_element_op_,
                    arg.c_element_op_,
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.p_dropout_in_16bits_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.p_dropout_rescale_,
                    arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
                    arg.lse_grid_desc_m_,
                    arg.vgrad_grid_desc_n_o_,
                    arg.ygrad_grid_desc_m0_o_m1_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
                    arg.compute_base_ptr_of_batch_,
                    arg.c0_matrix_mask_,
                    arg.p_dropout_,
                    arg.seed_,
                    arg.offset_);
            };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
// to concern Gemm0's loop
#if 1
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }
#endif
            if(all_has_main_k_block_loop)
            {
                if(arg.is_dropout_)
                {
                    ave_time = launch_kernel(integral_constant<bool, true>{},
                                             integral_constant<bool, true>{});
                }
                else
                {
                    ave_time = launch_kernel(integral_constant<bool, true>{},
                                             integral_constant<bool, false>{});
                }
            }
            else if(!some_has_main_k_block_loop)
            {
                if(arg.is_dropout_)
                {
                    ave_time = launch_kernel(integral_constant<bool, false>{},
                                             integral_constant<bool, true>{});
                }
                else
                {
                    ave_time = launch_kernel(integral_constant<bool, false>{},
                                             integral_constant<bool, false>{});
                }
            }
            else
            {
                throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
                                         "has_main_k_block_loop or no_main_k_block_loop");
            }
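            // Editor's note (illustrative, not part of the diff): both flags are forwarded as
            // integral_constant values so the kernel specializes at compile time instead of
            // branching per thread. A minimal sketch of the same dispatch pattern:
            //
            //   template <bool HasMainKBlockLoop, bool IsDropout>
            //   void run_variant(std::integral_constant<bool, HasMainKBlockLoop>,
            //                    std::integral_constant<bool, IsDropout>)
            //   {
            //       if constexpr(HasMainKBlockLoop) { /* iterate the K main loop */ }
            //       if constexpr(IsDropout) { /* build the Philox keep-mask and rescale P */ }
            //   }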
return ave_time;
return ave_time;
}
}
...
@@ -989,36 +839,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -989,36 +839,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
static bool IsSupportedArgument(const Argument& arg)
{
{
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
{
return false;
return false;
}
}
        // TODO: Check if tensor specialization & strides mismatch
        // TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
bool all_has_main_k_block_loop = true;
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
bool some_has_main_k_block_loop = false;
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
for(std::size_t i = 0; i < arg.group_count_; i++)
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
{
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{
{
return false;
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// Check if having main loop
// vector is out of bounds
const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const
auto
MzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
];
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
const
auto
NzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
];
all_has_main_k_block_loop &= y;
const
auto
KzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
2
];
some_has_main_k_block_loop |= y;
const
auto
Gemm1NzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
3
];
// Note: we need raw lengths since threadwise copy can not handle vector load when
// part of vector is out of bounds
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
...
@@ -1035,14 +893,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1035,14 +893,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
}
}
// Check vector load/store requirement
// Check vector load/store requirement
const
auto
a_stride_lowest
=
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
? device_arg.a_mz_kz_strides_[1]
const
auto
b_stride_lowest
=
: device_arg.a_mz_kz_strides_[0];
BBlockTransferSrcVectorDim
==
2
?
arg
.
b_nz_kz_strides_
[
1
]
:
arg
.
b_nz_kz_strides_
[
0
];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
const
auto
b1_stride_lowest
=
? device_arg.b_nz_kz_strides_[1]
B1BlockTransferSrcVectorDim
==
2
?
arg
.
b1_nz_kz_strides_
[
1
]
:
arg
.
b1_nz_kz_strides_
[
0
];
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
const auto c_stride_lowest =
arg
.
c_mz_gemm1nz_strides_
[
1
];
// cshuffle assumes lowest dim in Gemm1Ns to be contiguous
device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
// contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
c_stride_lowest == 1))
...
@@ -1050,11 +912,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1050,11 +912,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
return false;
return false;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
arg
.
b_grid_desc_bk0_n_bk1_
,
kernel_arg.b_grid_desc_bk0_n_bk1_,
arg
.
b1_grid_desc_bk0_n_bk1_
,
kernel_arg.b1_grid_desc_bk0_n_bk1_,
arg
.
y_grid_desc_m_o_
,
device_arg.c_grid_desc_m_n_,
arg
.
block_2_ctile_map_
);
kernel_arg.block_2_ctile_map_))
{
return false;
}
}
// all gemm problems have to simultaneously meet has_main_k_block_loop or
// no_main_k_block_loop
if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop))
{
return false;
}
return true;
}
}
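The length and stride checks above encode the usual requirement for vectorized global loads/stores: the dimension being vectorized should be contiguous and its raw (unpadded) extent divisible by the scalars-per-vector. A hedged one-liner capturing that intent (simplified; the stride test in the code above only demands that at least one operand satisfies it):

bool can_vectorize(int raw_extent_lowest, int stride_lowest, int scalar_per_vector)
{
    return stride_lowest == 1 && raw_extent_lowest % scalar_per_vector == 0;
}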
// polymorphic
// polymorphic
...
@@ -1063,160 +938,82 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1063,160 +938,82 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
}
static
auto
MakeArgument
(
static auto MakeArgument(std::vector<const void*> p_a_vec,
const
DataType
*
p_a
,
std::vector<const void*> p_b_vec,
const
DataType
*
p_b
,
std::vector<const void*> p_b1_vec,
ZDataType
*
p_z
,
std::vector<void*> p_c_vec,
const
DataType
*
p_b1
,
std::vector<void*> p_z_vec,
const
DataType
*
p_c
,
std::vector<void*> p_lse_vec,
const
LSEDataType
*
p_lse
,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
const
DataType
*
p_ygrad_grid
,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
DataType
*
p_qgrad_grid
,
std::vector<ProblemDesc> problem_desc_vec,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op,
float
p_drop
,
float p_drop
out,
std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
{
{
return
Argument
{
p_a
,
return Argument{p_a_vec,
p_b
,
p_b_vec,
p_z
,
p_b1_vec,
p_b1
,
p_c_vec,
p_c
,
p_z_vec,
p_lse
,
p_lse_vec,
p_ygrad_grid
,
p_acc0_biases_vec,
p_qgrad_grid
,
p_acc1_biases_vec,
p_kgrad_grid
,
problem_desc_vec,
p_vgrad_grid
,
p_acc0_biases
,
p_acc1_biases
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
a_element_op,
a_element_op,
b_element_op,
b_element_op,
acc_element_op,
acc_element_op,
b1_element_op,
b1_element_op,
c_element_op,
c_element_op,
p_drop
,
p_drop
out,
seeds};
seeds};
}
}
static auto MakeInvoker() { return Invoker{}; }
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
MakeArgumentPointer(std::vector<const void*> p_a_vec,
const
void
*
p_a
,
std::vector<const void*> p_b_vec,
const
void
*
p_b
,
std::vector<const void*> p_b1_vec,
void
*
p_z
,
std::vector<void*> p_c_vec,
const
void
*
p_b1
,
std::vector<void*> p_z_vec,
const
void
*
p_c
,
std::vector<void*> p_lse_vec,
const
void
*
p_lse
,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
const
void
*
p_ygrad_grid
,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
void
*
p_qgrad_grid
,
std::vector<ProblemDesc> problem_desc_vec,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op,
float
p_drop
,
float p_drop
out,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
//
override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
return std::make_unique<Argument>(p_a_vec,
static_cast
<
const
DataType
*>
(
p_b
),
p_b_vec,
static_cast
<
ZDataType
*>
(
p_z
),
p_b1_vec,
static_cast
<
const
DataType
*>
(
p_b1
),
p_c_vec,
static_cast
<
const
DataType
*>
(
p_c
),
p_z_vec,
static_cast
<
const
LSEDataType
*>
(
p_lse
),
p_lse_vec,
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
p_acc0_biases_vec,
static_cast
<
DataType
*>
(
p_qgrad_grid
),
p_acc1_biases_vec,
static_cast
<
DataType
*>
(
p_kgrad_grid
),
problem_desc_vec,
static_cast
<
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
acc1_biases_gs_ms_gemm1ns_strides
,
a_element_op,
a_element_op,
b_element_op,
b_element_op,
acc_element_op,
acc_element_op,
b1_element_op,
b1_element_op,
c_element_op,
c_element_op,
p_drop
,
p_drop
out,
seeds);
seeds);
}
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
//
override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{});
}
}
...
@@ -1227,7 +1024,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1227,7 +1024,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
auto str = std::stringstream();
auto str = std::stringstream();
// clang-format off
// clang-format off
        str << "DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle"
        str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
<< "<"
<< "<"
<< BlockSize << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< MPerBlock << ", "
...
@@ -1249,6 +1046,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
...
@@ -1249,6 +1046,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
return str.str();
return str.str();
}
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
}
};
};
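Because the grouped op stages its GroupKernelArg array through device memory, callers must attach a workspace sized by GetWorkSpaceSize() before running the invoker. A hedged host-side sketch (argument construction elided; assumes the public p_workspace_ member of BaseArgument and the usual CK invoker flow, and is not code from this commit):

// given an instance `op` of this device op and fully populated argument inputs
auto arg_ptr = op.MakeArgumentPointer(/* ...pointers, problem descriptors, element-wise ops, p_dropout, seeds... */);
void* p_workspace = nullptr;
hipMalloc(&p_workspace, op.GetWorkSpaceSize(arg_ptr.get()));
arg_ptr->p_workspace_ = p_workspace; // Invoker::Run() copies group_kernel_args_ here before launch
auto invoker_ptr = op.MakeInvokerPointer();
float ave_time = invoker_ptr->Run(arg_ptr.get(), StreamConfig{nullptr, true});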
} // namespace device
} // namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
66052232
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
@@ -15,6 +16,7 @@
...
@@ -15,6 +16,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
...
@@ -30,6 +32,7 @@ template <typename DataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename QGridDesc_K0_M_K1,
          typename KGridDesc_K0_N_K1,
          typename ZGridDesc_M_N,
          typename VGridDesc_N0_O_N1,
          typename CGridDesc_M_N,
          typename LSEGridDesc_M,
...
@@ -80,8 +83,23 @@ template <typename DataType,
          bool PadN,
          bool MaskOutUpperTriangle,
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
    template <typename T>
    struct TypeMap
    {
        using type = T;
    };

#if defined(__gfx90a__)
    template <>
    struct TypeMap<ck::half_t>
    {
        using type = ck::bhalf_t;
    };
#endif

    using LDSDataType = typename TypeMap<DataType>::type;

    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");
...
@@ -93,7 +111,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;

    // K1 should be Number<...>
    // Gemm0
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
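    // Editor's note (illustrative, not part of the diff): AK1Value is the compile-time packing
    // factor along K, so one KPerBlock tile is walked as AK0 x AK1 with AK0 = KPerBlock / AK1Value;
    // e.g. KPerBlock = 32 with AK1 = 8 means 4 outer K0 steps, each covering 8 contiguous K
    // elements (the unit the vectorized A loads move).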
...
@@ -113,6 +134,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -113,6 +134,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
// C desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
LDSDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
{
constexpr
auto
mfma
=
MfmaSelector
<
LDSDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
MPerXdl
,
MPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
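GetGemm0WaveIdx and GetGemm0WaveMNIdx introduced above undo the merge transforms with plain division and modulo: a flat thread id becomes (mwave, nwave, lane), and a lane becomes (input-block, m-within-xdl). A hedged arithmetic sketch of the same decomposition, using made-up tile numbers (2x2 waves, wave size 64, MPerXdl 32) that are not tuning parameters from this commit:

// --- illustrative sketch, not part of this commit ---
#include <cstdio>

int main()
{
    const int Gemm0MWaves = 2, Gemm0NWaves = 2, WaveSize = 64, MPerXdl = 32; // assumed values

    const int tids[] = {0, 63, 64, 200};
    for(int tid : tids)
    {
        // merge(Gemm0MWaves, Gemm0NWaves, WaveSize) -> tid, so unmerge with / and %
        const int lane  = tid % WaveSize;
        const int nwave = (tid / WaveSize) % Gemm0NWaves;
        const int mwave = tid / (WaveSize * Gemm0NWaves);

        // merge(WaveSize / MPerXdl, MPerXdl) -> lane
        const int m_in_xdl  = lane % MPerXdl;
        const int input_blk = lane / MPerXdl;

        std::printf("tid %3d -> wave (%d,%d), lane %2d -> (blk %d, m %d)\n",
                    tid, mwave, nwave, lane, input_blk, m_in_xdl);
    }
}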
@@ -347,6 +427,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
 
+    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
+
     // S / dP Gemm (type 1 rcr)
     struct Gemm0
     {
@@ -388,7 +471,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_M_K1,
             decltype(a_block_desc_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -413,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -428,13 +511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             true, // DstResetCoord
             NumGemmKPrefetchStage>;
 
-        static constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        static constexpr index_t KPack =
+            math::max(math::lcm(AK1, BK1),
+                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            LDSDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
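KPack above is the K-packing per MFMA issue: the least common multiple of the A and B vector widths, but never smaller than the selected instruction's k_per_blk; the only functional change in this hunk is that the MFMA is now selected with the LDS element type. Illustrative arithmetic under assumed values (AK1 = BK1 = 8, k_per_blk = 4), not numbers taken from this commit:

// --- illustrative sketch, not part of this commit ---
#include <algorithm>
#include <numeric>

constexpr int AK1 = 8, BK1 = 8, k_per_blk = 4; // assumed example values
constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk);
static_assert(KPack == 8, "lcm(AK1, BK1) already covers the MFMA k_per_blk here");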
@@ -496,7 +580,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            DataType,
+            LDSDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -515,7 +599,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             B1BlockTransferThreadClusterLengths_BK0_N_BK1,
             B1BlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
             B1BlockTransferSrcAccessOrder,
@@ -546,11 +630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
 
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            LDSDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -566,7 +650,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             GemmKPack,
             true,      // TransposeC
             GemmKPack, // AMmaKStride
-            GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}
+            GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                 .K0PerXdlops /* BMmaKStride */>;
     };
@@ -598,7 +682,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMPack =
             math::max(math::lcm(A_M1, B_M1),
-                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
         using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
         using BThreadClusterLengths =
@@ -720,12 +804,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
             false>;
 
+        template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            DataType,
+            LDSDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            tensor_operation::element_wise::PassThrough,
+            ElementwiseOp,
             Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
                      Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
@@ -752,7 +837,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             typename Gemm2Params_N_O_M::BThreadClusterLengths,
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            LDSDataType,
             GridDesc_M0_O_M1,
             decltype(b_block_desc_m0_o_m1),
             typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
@@ -769,7 +854,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         using BlockwiseGemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                DataType,
+                                                                LDSDataType,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_m0_n_m1),
                                                                 decltype(b_block_desc_m0_o_m1),
@@ -836,7 +921,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
-        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
+        static constexpr index_t SrcScalarPerVector = 16 / sizeof(FloatGemmAcc);
         static constexpr auto ThreadClusterLength_O =
             Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
@@ -848,7 +933,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
-                                        DataType,
+                                        FloatGemmAcc,
                                         ThreadSliceLength_M * ThreadSliceLength_O,
                                         true>;
@@ -1010,7 +1095,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         static constexpr auto b2_block_desc_m0_o_m1 = GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
 
-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
 
         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -1046,13 +1131,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
     {
         const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                          SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(DataType);
+                                        sizeof(LDSDataType);
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(DataType);
+            sizeof(LDSDataType);
         const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
                                               SharedMemTrait::ygrad_block_space_size_aligned) *
-                                             sizeof(DataType);
+                                             sizeof(LDSDataType);
         const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                            SharedMemTrait::reduction_space_size_aligned) *
@@ -1074,6 +1159,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
               typename YGradGridDesc_M0_O_M1>
     __device__ static void Run(const DataType* __restrict__ p_q_grid,
                                const DataType* __restrict__ p_k_grid,
+                               unsigned short* __restrict__ p_z_grid,
                                const DataType* __restrict__ p_v_grid,
                                const DataType* __restrict__ p_y_grid,
                                const FloatLSE* __restrict__ p_lse_grid,
@@ -1089,6 +1175,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                const CElementwiseOperation& c_element_op,
                                const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
                                const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
+                               const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+                                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                                const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
                                    y_grid_desc_mblock_mperblock_oblock_operblock,
@@ -1096,8 +1184,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
                                const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
                                const Block2CTileMap& block_2_ctile_map,
-                               const C0MatrixMask& c0_matrix_mask)
+                               const C0MatrixMask& c0_matrix_mask,
+                               FloatGemmAcc p_dropout,
+                               ck::philox& ph)
     {
+        const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+        const FloatGemmAcc rp_dropout    = 1.0f / p_dropout;
+
         const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
         const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1147,11 +1240,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // Gemm0: LDS allocation for A and B: be careful of alignment
         auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
             Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
         auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // Gemm0: gridwise GEMM pipeline
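Run() now receives the keep probability p_dropout and a ck::philox generator, and immediately derives two constants: a 16-bit comparison threshold for the drawn random numbers and the reciprocal 1/p used to rescale surviving elements (inverted dropout). A standalone sketch of that arithmetic; the fixed rand16 value and the <= keep test below are assumptions for illustration only, not the kernel's random path:

// --- illustrative sketch, not part of this commit ---
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float p_dropout = 0.8f; // probability of keeping an element
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    const float rp_dropout = 1.0f / p_dropout;

    const uint16_t rand16 = 40000;                   // pretend Philox output in [0, 65535]
    const bool keep = rand16 <= p_dropout_in_16bits; // keeps roughly p_dropout of the samples
    const float x = 0.25f;
    const float y = keep ? x * rp_dropout : 0.0f;    // inverted dropout keeps E[y] == x

    std::printf("threshold=%u keep=%d y=%f\n", unsigned(p_dropout_in_16bits), int(keep), y);
}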
@@ -1243,11 +1336,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
 
         // Gemm1: VGPR allocation for A and LDS allocation for B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
         auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         // dQ: transform input and output tensor descriptors
@@ -1331,6 +1424,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             decltype(thread_cluster_desc_m_n),
             decltype(thread_slice_desc_m_n)>{};
 
+        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
+            p_dropout_in_16bits, rp_dropout};
+
         auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
@@ -1360,6 +1456,75 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                   acc0_thread_origin[I2],   // mwave
                                   acc0_thread_origin[I4])}; // mperxdl
 
+        //
+        // z vgpr copy to global
+        //
+        // z matrix threadwise desc
+        constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
+                                                           I1, // NBlockID
+                                                           m0, // MRepeat
+                                                           n0, // NRepeat
+                                                           m1, // MWaveId
+                                                           n1, // NWaveId
+                                                           m2, // MPerXdl
+                                                           n2, // NGroupNum
+                                                           n3, // NInputNum
+                                                           n4)); // registerNum
+
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     unsigned short,
+                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
+                     true>
+            z_tenor_buffer;
+        z_tenor_buffer.Clear();
+
+        // z matrix global desc
+        /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
+
+        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
+
+        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+
+        const auto wave_id     = GetGemm0WaveIdx();
+        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
+
+        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
+            ushort,
+            ushort,
+            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<I1, // MBlockId
+                     I1, // NBlockID
+                     m0, // MRepeat
+                     n0, // NRepeat
+                     m1, // MWaveId
+                     n1, // NWaveId
+                     m2, // MPerXdl
+                     n2, // NGroupNum
+                     n3, // NInputNum
+                     n4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+            9,  // DstVectorDim
+            n4, // DstScalarPerVector
+            InMemoryDataOperationEnum::Set,
+            1, // DstScalarStrideInVector
+            true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                  make_multi_index(block_work_idx[I0], // MBlockId
+                                   0,                  // NBlockId
+                                   0,                  // mrepeat
+                                   0,                  // nrepeat
+                                   wave_id[I0],        // MWaveId
+                                   wave_id[I1],        // NWaveId
+                                   wave_m_n_id[I1],    // MPerXdl
+                                   0,                  // group
+                                   wave_m_n_id[I0],    // NInputIndex
+                                   0),
+                  tensor_operation::element_wise::PassThrough{}};
+
         //
         // set up dV / dK Gemm (type 3 crr)
         //
@@ -1367,11 +1532,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
         auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+            static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
             Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
 
         // dV: transform input and output tensor descriptors
@@ -1379,10 +1544,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
 
         // dV: A matrix VGPR-to-LDS blockwise copy
-        auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+        auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
+            typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
                 Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                 Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
-                tensor_operation::element_wise::PassThrough{}};
+                tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
 
         // dV: B matrix global-to-LDS blockwise copy
         auto vgrad_gemm_tile_ygrad_blockwise_copy =
@@ -1407,11 +1573,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             make_multi_index(
                 I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
 
-        auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
-            vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
+        auto vgrad_thread_copy_vgpr_to_global =
+            typename Gemm2::template CBlockwiseCopy<decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+                                                    tensor_operation::element_wise::Scale>(
                 vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                 vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-                tensor_operation::element_wise::PassThrough{});
+                tensor_operation::element_wise::Scale{rp_dropout});
 
         // dK: transform input and output tensor descriptors
         const auto q_grid_desc_m0_k_m1 =
@@ -1422,7 +1590,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
 
         // dK: A matrix VGPR-to-LDS blockwise copy
-        auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+        auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
+            typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
                 Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                 Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
                 tensor_operation::element_wise::PassThrough{}};
@@ -1487,7 +1656,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
-            DataType,
+            FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
@@ -1496,7 +1665,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
             1,                                 // SrcScalarStrideInVector
             true /* ResetCoordAfterRun */,
-            true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
+            false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
                                              y_thread_data_on_grid_idx);
 
         auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
@@ -1574,7 +1743,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                 const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
                 y_dot_ygrad_block_accum_buf.AtomicAdd(
-                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
+                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropoutD1
             });
 
             block_sync_lds();
@@ -1595,6 +1764,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
 
+        const index_t K    = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
+        const float scalar = 1.0f / std::sqrt(K);
         // Initialize dQ
         qgrad_thread_buf.Clear();
@@ -1675,14 +1846,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 }
                 else
                 {
-                    s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+                    s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
                 }
             });
         }
         else
         {
             static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                [&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
+                [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
         }
 
         block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
@@ -1691,6 +1862,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         // scaling is already performed in the preceding statements with s_element_op
         blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
 
+        // save z to global
+        if(p_z_grid)
+        {
+            // P_dropped
+            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
+                                                    decltype(z_tenor_buffer),
+                                                    true>(s_slash_p_thread_buf, ph, z_tenor_buffer);
+
+            z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                             z_tenor_buffer,
+                                             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                             z_grid_buf);
+        }
+        else
+        {
+            // P_dropped
+            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
+                s_slash_p_thread_buf, ph);
+        }
+
         block_sync_lds(); // wait for gemm1 LDS read
 
         SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
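The branch added in the hunk above calls BlockwiseDropout::ApplyDropout on the softmax tile: it drops elements of P in place and, when a Z pointer was supplied, also records the drawn 16-bit randoms so the identical mask can be reproduced later. A hedged, host-side sketch of what one such per-element step amounts to; the function name and the keep test are illustrative assumptions, not the BlockwiseDropout implementation:

// --- illustrative sketch, not part of this commit ---
#include <cstdint>
#include <vector>

void apply_dropout(std::vector<float>& p,
                   const std::vector<uint16_t>& rand16, // pretend Philox output per element
                   uint16_t threshold,
                   float rp_dropout,
                   std::vector<uint16_t>* saved_z)
{
    for(size_t i = 0; i < p.size(); ++i)
    {
        const bool keep = rand16[i] <= threshold;
        p[i] = keep ? p[i] * rp_dropout : 0.0f; // rescale kept elements, zero dropped ones
        if(saved_z != nullptr)
            (*saved_z)[i] = rand16[i]; // replayable mask, analogous to the Z tensor
    }
}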
@@ -1701,7 +1894,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                                         "");
 
             // TODO: tune gemm2 pipeline
-            // dV = P^T * dY
+            // dV = P_drop^T * dY
             v_slash_k_grad_thread_buf.Clear();
             static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
                 // load VGrad Gemm B
@@ -1781,8 +1974,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 constexpr auto m =
                     pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
                 // dS and P has same thread buf layout
-                sgrad_thread_buf(i) = s_slash_p_thread_buf[i] *
-                                      (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                if(s_slash_p_thread_buf[i] >= 0)
+                {
+                    sgrad_thread_buf(i) =
+                        s_slash_p_thread_buf[i] *
+                        (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                }
+                else
+                {
+                    sgrad_thread_buf(i) =
+                        s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
+                }
             });
 
             // gemm dQ
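For kept elements the new branch is the usual softmax/attention backward identity dS = P * (dP - D) with D = rowsum(Y * dY taken elementwise); dropped entries appear to be flagged through the sign convention ApplyDropout leaves in s_slash_p (hence the Relu on the dV copy earlier and the >= 0 test here), and the else path handles that case. A tiny worked example of the kept-element formula with made-up numbers:

// --- illustrative sketch, not part of this commit ---
#include <cstdio>

int main()
{
    const float P = 0.4f, dP = 1.5f, D = 0.9f; // one kept probability, its dP, and the row dot D
    const float dS_kept = P * (dP - D);        // matches the >= 0 branch above
    std::printf("dS (kept) = %f\n", dS_kept);
}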
@@ -1922,6 +2124,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
 
+                z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+
             } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
 
         // shuffle dQ and write
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace
ck
{
template
<
typename
DataType
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatLSE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
SElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
QGridDesc_K0_M_K1
,
typename
KGridDesc_K0_N_K1
,
typename
ZGridDesc_M_N
,
typename
VGridDesc_N0_O_N1
,
typename
CGridDesc_M_N
,
typename
LSEGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
B1K1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1ThreadTransferSrcResetCoordinateAfterRun
,
index_t
B1BlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
// Gemm1
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
// C desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
{
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
MPerXdl
,
MPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
1
,
1
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
Gemm1NWaves
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm1NXdlPerWave
,
Gemm1NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
template
<
typename
AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
>
__host__
__device__
static
constexpr
auto
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1
(
const
AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4
&
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const
auto
m0
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
const
auto
n0
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
const
auto
m1
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
const
auto
n1
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
const
auto
m2
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
const
auto
n2
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
const
auto
n3
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
const
auto
n4
=
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
return
transform_tensor_descriptor
(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
n0
,
n1
,
n2
,
n3
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
m0
,
m1
,
m2
)),
make_pass_through_transform
(
n4
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
6
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
B1K0
,
Number
<
Gemm1NPerBlock
>
{},
B1K1
),
make_tuple
(
Number
<
Gemm1NPerBlock
+
B1BlockLdsExtraN
>
{}
*
B1K1
,
B1K1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
Gemm1NPerBlock
/
(
Gemm1NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
template
<
typename
Gemm2Param
>
__host__
__device__
static
constexpr
auto
GetA2BlockDescriptor_M0_N_M1
()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
Gemm2Param
::
A_M0
>
{},
Number
<
Gemm2Param
::
Free0_N
>
{},
Number
<
Gemm2Param
::
A_M1
>
{}),
make_tuple
(
Number
<
Gemm2Param
::
Free0_N
+
Gemm2Param
::
A_LdsPad
>
{}
*
Number
<
Gemm2Param
::
A_M1
>
{},
Number
<
Gemm2Param
::
A_M1
>
{},
I1
));
}
template
<
typename
Gemm2Param
>
__host__
__device__
static
constexpr
auto
GetB2BlockDescriptor_M0_O_M1
()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
Gemm2Param
::
B_M0
>
{},
Number
<
Gemm2Param
::
Free1_O
>
{},
Number
<
Gemm2Param
::
B_M1
>
{}),
make_tuple
(
Number
<
Gemm2Param
::
Free1_O
+
Gemm2Param
::
B_LdsPad
>
{}
*
Number
<
Gemm2Param
::
B_M1
>
{},
Number
<
Gemm2Param
::
B_M1
>
{},
I1
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_n0_o_n1
.
GetLength
(
I1
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
// check gemm0 gridwise gemm pipeline
const
auto
num_gemm0_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm0_k_loop
))
{
return
false
;
}
// check gemm1 gridwise gemm pipeline
if
(
!
(
NPerBlock
%
Gemm1KPerBlock
==
0
))
{
return
false
;
}
const
auto
num_gemm1_k_inner_loop
=
NPerBlock
/
Gemm1KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm1_k_inner_loop
))
{
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
Gemm1NPerBlock
;
const
auto
y_grid_desc_mblock_mperblock_oblock_operblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
y_grid_desc_mblock_mperblock_oblock_operblock
;
}
__host__
__device__
static
constexpr
auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
const
LSEGridDesc_M
&
lse_grid_desc_m
)
{
const
index_t
M
=
lse_grid_desc_m
.
GetLength
(
I0
);
const
index_t
MBlock
=
M
/
MPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
const
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
transform_tensor_descriptor
(
lse_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MXdlPerWave
>
{},
MWave
,
Number
<
MPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}));
return
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
Gemm1NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
// S / dP Gemm (type 1 rcr)
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
GridDesc_K0_M_K1
>
using
ABlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_M_K1
,
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
template
<
typename
GridDesc_K0_N_K1
>
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b_block_desc_bk0_n_bk1
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
true
>
;
// TransposeC
static
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
};
    // Y / dQ Gemm (type 2 rrr)
    template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4,
              typename ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4>
    struct Gemm1
    {
        private:
        static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I0);
        static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I1);
        static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I2);
        static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I3);
        static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I4);
        static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I5);
        static constexpr auto n3 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
        static constexpr auto n4 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I7);

        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
        static constexpr auto N3 = ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);

        public:
        static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / n4 / N3>{};
        static constexpr auto AThreadSliceLength_M  = Number<m0 * m1 * m2>{};
        static constexpr auto AThreadSliceLength_K1 = Number<n4>{};

        // A source matrix layout in AccVGPR
        static constexpr auto a_src_thread_desc_k0_m_k1 =
            GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
                ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});

        // A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
        static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
            make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        static constexpr auto ASrcScalarPerVector = n4;

        using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());

        using ABlockwiseCopy =
            ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
                                                         DataType,
                                                         decltype(a_src_thread_desc_k0_m_k1),
                                                         decltype(a_thread_desc_k0_m_k1),
                                                         tensor_operation::element_wise::PassThrough,
                                                         AThreadSliceLengths_K0_M_K1,
                                                         Sequence<1, 0, 2>,
                                                         2,
                                                         ASrcScalarPerVector>;
        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<B1K0, Gemm1NPerBlock, B1K1>,
                                                B1BlockTransferThreadClusterLengths_BK0_N_BK1,
                                                B1BlockTransferThreadClusterArrangeOrder,
                                                DataType,
                                                DataType,
                                                GridDesc_K0_N_K1,
                                                decltype(b_block_desc_bk0_n_bk1),
                                                B1BlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                B1BlockTransferSrcVectorDim,
                                                2,
                                                B1BlockTransferSrcScalarPerVector,
                                                B1BlockTransferDstScalarPerVector_BK1,
                                                1,
                                                1,
                                                B1ThreadTransferSrcResetCoordinateAfterRun,
                                                true, // DstResetCoord
                                                NumGemmKPrefetchStage>;

        // for a_block_slice_copy_step to be able to address static buffers, it MUST be a
        // tuple-based container as well as containing ONLY integral constants
        static constexpr auto a_block_slice_copy_step =
            make_tuple(AThreadSliceLength_K0, I0, I0);
        static constexpr auto b_block_slice_copy_step =
            make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        static constexpr index_t GemmKPack =
            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;

        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize,
            DataType,
            FloatGemmAcc,
            decltype(a_thread_desc_k0_m_k1),
            decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            MPerBlock,
            Gemm1NPerBlock,
            Gemm1KPerBlock,
            MPerXdl,
            NPerXdl,
            MXdlPerWave,
            Gemm1NXdlPerWave,
            GemmKPack,
            true,      // TransposeC
            GemmKPack, // AMmaKStride
            GemmKPack *
                XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops
            /* BMmaKStride */>;
    };
    // dV / dK Gemm (type 3 crr)
    // Describes tuning parameter for C2_n_o = A2_n_m * B2_m_o
    template <index_t Sum_M_ = MPerXdl * 2>
    struct Gemm2Params_N_O_M_
    {
        static constexpr index_t Free0_N = NPerBlock;
        static constexpr index_t Free1_O = Gemm1NPerBlock;
        static constexpr index_t Sum_M   = Sum_M_;

        static constexpr index_t A_M1     = 8; // P will be row-major
        static constexpr index_t A_M0     = Sum_M / A_M1;
        static constexpr index_t A_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static constexpr index_t B_M1     = 2; // dY assumed row-major, typically =2 for fp16
        static constexpr index_t B_M0     = Sum_M / B_M1;
        static constexpr index_t B_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static_assert(Sum_M % MPerXdl == 0, "");

        static constexpr index_t BSrcVectorDim       = 1; // Free1_O dimension
        static constexpr index_t BSrcScalarPerVector = 4;

        static constexpr index_t GemmNWave   = 2;
        static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
        static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
        static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
        static constexpr index_t GemmMPack =
            math::max(math::lcm(A_M1, B_M1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
        using BThreadClusterLengths =
            Sequence<BlockSize / (Free1_O / BSrcScalarPerVector), Free1_O / BSrcScalarPerVector, 1>;
        using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
        {
            // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
            constexpr index_t m  = Gemm2Params_N_O_M::Sum_M - 1;
            constexpr index_t m2 = m % MPerXdl;
            constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
            constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;

            // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
            constexpr index_t n  = Gemm2Params_N_O_M::Free0_N - 1;
            constexpr index_t n2 = n % NPerXdl;
            constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
            constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;

            // assume 256 decomposed into 2 x 4 x 32
            // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
            // 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
            return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
        }

        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1()
        {
            return generate_sequence_v2(
                [](auto I) { return GetABlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
                Number<4>{});
        }

        using ABlockSliceLengths_M0_N0_M1_N1 = decltype(GetABlockSliceLengths_M0_N0_M1_N1());
    };
    using Gemm2Params_N_O_M = Gemm2Params_N_O_M_<>; // tune later
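    // Illustrative sizing only (not from the source; assumes typical tuning values MPerXdl = 32,
    // BlockSize = 256, 64-lane wavefronts): Sum_M = 32 * 2 = 64, so A_M0 = 64 / 8 = 8 and
    // B_M0 = 64 / 2 = 32, and GemmOWave = 256 / 64 / 2 = 2. The dV/dK gemm then consumes the
    // MPerBlock reduction dimension in MPerBlock / Sum_M chunks per outer iteration.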
    // dV / dK Gemm (type 3 crr)
    template <typename Gemm2Params_N_O_M, typename ASrcBlockwiseGemm>
    struct Gemm2
    {
        private:
        static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // repeat
        static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
        static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // wave
        static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
        static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); // xdl
        static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
        static constexpr auto N3 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
        static constexpr auto N4 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

        public:
        // A source matrix layout in VGPR, src of VGPR-to-LDS copy
        static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        // A matrix in LDS memory, dst of blockwise copy
        static constexpr auto a_block_desc_m0_n_m1 =
            GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_m0_o_m1 =
            GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
        __host__ __device__ static constexpr auto MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4()
        {
            const auto M0_ = a_block_desc_m0_n_m1.GetLength(I0);
            const auto N_  = a_block_desc_m0_n_m1.GetLength(I1);
            const auto M1_ = a_block_desc_m0_n_m1.GetLength(I2);

            const auto a_block_desc_m_n = transform_tensor_descriptor(
                a_block_desc_m0_n_m1,
                make_tuple(make_merge_transform_v3_division_mod(make_tuple(M0_, M1_)),
                           make_pass_through_transform(N_)),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
            // variable I1 there
            return transform_tensor_descriptor(
                a_block_desc_m_n,
                make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2)),
                           make_unmerge_transform(make_tuple(I1, N1, N2, N3, N4))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
        }
        // Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
        // destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
        // operation before returning the values
        __host__ __device__ static auto MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
        {
            const auto a_thread_origin_on_block_idx =
                ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);

            constexpr auto c_block_slice_lengths_m0_n0_m1_n1 =
                typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};

            // mrepeat, nrepeat,
            // mwaves, nwaves,
            return make_tuple(
                a_thread_origin_on_block_idx[I0],                                         // mrepeat
                a_thread_origin_on_block_idx[I1],                                         // nrepeat
                a_thread_origin_on_block_idx[I2] % c_block_slice_lengths_m0_n0_m1_n1[I2], // mwave
                a_thread_origin_on_block_idx[I3] % c_block_slice_lengths_m0_n0_m1_n1[I3], // nwave
                a_thread_origin_on_block_idx[I4],                                         // xdlops
                a_thread_origin_on_block_idx[I5],
                a_thread_origin_on_block_idx[I6],
                a_thread_origin_on_block_idx[I7]);
        }
        static constexpr auto a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4();

        using ASrcBlockSliceWindowIterator =
            SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
                              Sequence<0, 1, 2, 3>,
                              typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
                              false>;

        template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
        using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            DataType,
            decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            ElementwiseOp,
            Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
                     Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
                     I1,
                     I1,
                     I1,
                     N2,
                     I1,
                     N4>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            7,                              // DstVectorDim
            1,                              // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1,                              // DstScalarStrideInVector
            true>;
        template <typename GridDesc_M0_O_M1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            typename Gemm2Params_N_O_M::BBlockSliceLengths,
            typename Gemm2Params_N_O_M::BThreadClusterLengths,
            typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
            DataType,
            DataType,
            GridDesc_M0_O_M1,
            decltype(b_block_desc_m0_o_m1),
            typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
            Sequence<1, 0, 2>,
            Gemm2Params_N_O_M::BSrcVectorDim,
            2, // DstVectorDim
            Gemm2Params_N_O_M::BSrcScalarPerVector,
            Gemm2Params_N_O_M::B_M1,
            1,
            1,
            true,
            true,
            1>;
        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                DataType,
                                                                FloatGemmAcc,
                                                                decltype(a_block_desc_m0_n_m1),
                                                                decltype(b_block_desc_m0_o_m1),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                Gemm2Params_N_O_M::GemmNRepeat,
                                                                Gemm2Params_N_O_M::GemmORepeat,
                                                                Gemm2Params_N_O_M::GemmMPack,
                                                                true>; // TransposeC

        static constexpr auto b_block_slice_copy_step =
            make_multi_index(Gemm2Params_N_O_M::B_M0, 0, 0);
        static constexpr auto c_block_slice_copy_step =
            make_multi_index(Gemm2Params_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
        static constexpr auto b_block_reset_copy_step =
            make_multi_index(-MPerBlock / Gemm2Params_N_O_M::B_M1, 0, 0);
        template <typename CGradDesc_N_O>
        __host__ __device__ static auto
        MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(const CGradDesc_N_O& c_grid_desc_n_o)
        {
            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
            // variable I1 there
            const auto c_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
                c_grid_desc_n_o,
                make_tuple(
                    make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmNWave, MPerXdl)),
                    make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmOWave, NPerXdl))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

            const auto c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
                BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
                    c_grid_desc_n0_o0_n1_o1_n2_o2);

            return c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4;
        }

        static constexpr auto c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
            BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        __host__ __device__ static auto GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4()
        {
            return to_multi_index(
                BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
        }
        template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
                  typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
        using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            DataType,
            decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
            CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
            ElementwiseOp,                                                // CElementwiseOperation
            decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // AccessOrder
            7,                                                            // VectorDim
            2,                                                            // ScalarPerVector
            InMemoryDataOperationEnum::AtomicAdd,                         // GlobalMemoryDataOperation
            1,                                                            // DstScalarStrideInVector
            true>;
    };
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
        static constexpr auto ThreadClusterLength_O =
            Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
        static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVector>{};
        static constexpr auto ThreadSliceLength_M =
            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

        using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        DataType,
                                        ThreadSliceLength_M * ThreadSliceLength_O,
                                        true>;

        using DstBufType =
            StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
    };
    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
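    // Illustrative decomposition only (not from the source; assumes fp16 DataType,
    // BlockSize = 256, MPerBlock = 128, Gemm1NPerBlock = 128): SrcScalarPerVector = 16 / 2 = 8,
    // ThreadClusterLength_O = 128 / 8 = 16, ThreadClusterLength_M = 256 / 16 = 16,
    // ThreadSliceLength_O = 8, ThreadSliceLength_M = 128 * 16 / 256 = 8, so both static_asserts
    // (16 * 8 == 128) hold.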
    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
    struct PGradGemmTile_M_N_O
    {
        // TODO:
        // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
        // things more concise
        template <typename YGradGridDesc_M0_O_M1_>
        __device__ static auto
        MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
        {
            const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
            const auto O  = ygrad_grid_desc_m0_o_m1.GetLength(I1);
            const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);

            constexpr auto Y_O1 = AK1;
            const auto Y_O0     = O / Y_O1;

            const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
                ygrad_grid_desc_m0_o_m1,
                make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ygrad_grid_desc_o0_m_o1;
        }
        template <typename VGridDesc_N0_O_N1_>
        __device__ static auto MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
        {
            const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
            const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
            const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);

            constexpr auto V_O1 = BK1;
            const auto V_O0     = O / V_O1;

            const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
                v_grid_desc_n0_o_n1,
                make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return v_grid_desc_o0_n_o1;
        }
    };
    // QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
    struct QGradGemmTile_M_K_N
    {
        template <typename QGridDesc_K0_M_K1_>
        __device__ static auto MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
            const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
        {
            const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
            const auto M  = q_grid_desc_k0_m_k1.GetLength(I1);
            const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
            const auto K  = K0 * K1;

            const auto MBlock = M / MPerBlock;
            const auto KBlock = K / Gemm1NPerBlock;

            // NOTE: QGrad gemm is similar to Y gemm
            const auto q_grid_desc_m_k = transform_tensor_descriptor(
                q_grid_desc_k0_m_k1,
                make_tuple(make_pass_through_transform(M),
                           make_merge_transform_v3_division_mod(make_tuple(K0, K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return transform_tensor_descriptor(
                q_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                           make_unmerge_transform(make_tuple(KBlock, Number<Gemm1NPerBlock>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
        }
        template <typename KGridDesc_K0_N_K1_>
        __device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
        {
            const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
            const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
            const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);

            constexpr auto K_N1 = B1K1;
            const auto K_N0     = N / K_N1;

            const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
                k_grid_desc_k0_n_k1,
                make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
                           make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return k_grid_desc_n0_k_n1;
        }
    };
    struct KGradGemmTile_N_K_M
    {
        // B position
        template <typename QGridDesc_K0_M_K1_>
        __device__ static auto MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
        {
            const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
            const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
            const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);

            constexpr auto Q_M1 = B1K1;
            const auto Q_M0     = M / Q_M1;

            const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
                q_grid_desc_k0_m_k1,
                make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
                           make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return q_grid_desc_m0_k_m1;
        }
        // C position
        template <typename KGridDesc_K0_N_K1_>
        __device__ static auto MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
        {
            const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
            const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
            const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);

            return transform_tensor_descriptor(
                k_grid_desc_k0_n_k1,
                make_tuple(make_pass_through_transform(N),
                           make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    };
    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto a_block_desc_ak0_m_ak1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto a2_block_desc_m0_n_m1 =
            GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
        static constexpr auto b2_block_desc_m0_o_m1 =
            GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();

        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
            a2_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
            b2_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;
        static constexpr auto a2_block_space_offset = 0;
        static constexpr auto b2_block_space_offset = p_block_space_size_aligned.value;

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
            math::integer_least_multiple(BlockSize, max_lds_align);
        static constexpr auto reduction_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
                                         SharedMemTrait::b_block_space_size_aligned) *
                                        sizeof(DataType);
        const index_t gemm1_bytes_end = (SharedMemTrait::b1_block_space_offset +
                                         SharedMemTrait::b1_block_space_size_aligned) *
                                        sizeof(DataType);
        const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
                                              SharedMemTrait::ygrad_block_space_size_aligned) *
                                             sizeof(DataType);
        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                           SharedMemTrait::reduction_space_size_aligned) *
                                          sizeof(FloatGemmAcc);
        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end,
                         gemm1_bytes_end,
                         vgrad_gemm_bytes_end,
                         softmax_bytes_end,
                         c_block_bytes_end);
    }
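    // Note: the byte counts above are combined with math::max rather than summed because the
    // kernel stages run serially and re-use the same p_shared allocation; several of the
    // *_space_offset values in SharedMemTrait alias offset 0 for exactly that reason.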
    template <bool HasMainKBlockLoop,
              typename Block2CTileMap,
              typename C0MatrixMask,
              typename VGradGridDescriptor_N_O,
              typename YGradGridDesc_M0_O_M1>
    __device__ static void Run(const DataType* __restrict__ p_q_grid,
                               const DataType* __restrict__ p_k_grid,
                               unsigned short* __restrict__ p_z_grid,
                               const DataType* __restrict__ p_v_grid,
                               const DataType* __restrict__ p_y_grid,
                               const FloatLSE* __restrict__ p_lse_grid,
                               const DataType* __restrict__ p_ygrad_grid,
                               DataType* __restrict__ p_qgrad_grid,
                               DataType* __restrict__ p_kgrad_grid,
                               DataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const SElementwiseOperation& s_element_op,
                               const B1ElementwiseOperation& b1_element_op,
                               const CElementwiseOperation& c_element_op,
                               const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
                               const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
                               const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
                                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                               const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                               const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
                                   y_grid_desc_mblock_mperblock_oblock_operblock,
                               const LSEGridDesc_M& lse_grid_desc_m,
                               const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
                               const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
                               const Block2CTileMap& block_2_ctile_map,
                               const C0MatrixMask& c0_matrix_mask,
                               FloatGemmAcc p_dropout,
                               ck::philox& ph)
    {
        const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
        const FloatGemmAcc rp_dropout    = 1.0f / p_dropout;
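        // Worked example (illustrative, not from the source): with keep probability
        // p_dropout = 0.8, p_dropout_in_16bits = floor(0.8 * 65535) = 52428; philox-generated
        // 16-bit values are compared against this threshold to decide which elements survive,
        // and rp_dropout = 1 / 0.8 = 1.25 rescales the survivors so the expected value matches
        // the non-dropout path.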
        const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
        const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
        const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
        const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
        const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
        auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
        auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
        auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
        // divide block work by [M, O]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
                          y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this force m/o_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t o_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
//
// set up S / dP Gemm (type 1 rcr)
//
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
Gemm0
::
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
gemm0_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
Gemm0
::
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// Gemm0: gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const
auto
gemm0_gridwise_gemm_pipeline
=
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopScheduler
::
Default
>
();
// S: A matrix blockwise copy
auto
s_gemm_tile_q_blockwise_copy
=
typename
Gemm0
::
template
ABlockwiseCopy
<
decltype
(
q_grid_desc_k0_m_k1
)>(
q_grid_desc_k0_m_k1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
Gemm0
::
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// S: B matrix blockwise copy
auto
s_gemm_tile_k_blockwise_copy
=
typename
Gemm0
::
template
BBlockwiseCopy
<
decltype
(
k_grid_desc_k0_n_k1
)>(
k_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
b_element_op
,
Gemm0
::
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// S: blockwise gemm
auto
s_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
// TransposeC
auto
s_slash_p_thread_buf
=
s_blockwise_gemm
.
GetCThreadBuffer
();
const
auto
s_gemm_tile_a_block_reset_copy_step
=
make_multi_index
(
-
q_grid_desc_k0_m_k1
.
GetLength
(
I0
),
0
,
0
);
const
auto
s_gemm_tile_b_block_reset_copy_step
=
make_multi_index
(
-
k_grid_desc_k0_n_k1
.
GetLength
(
I0
),
NPerBlock
,
0
);
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
))
/
KPerBlock
);
// dP: transform input and output tensor descriptors
const
auto
ygrad_grid_desc_o0_m_o1
=
PGradGemmTile_M_N_O
::
MakeYGradGridDesc_O0_M_O1
(
ygrad_grid_desc_m0_o_m1
);
const
auto
v_grid_desc_o0_n_o1
=
PGradGemmTile_M_N_O
::
MakeVGridDesc_O0_N_O1
(
v_grid_desc_n0_o_n1
);
// dP: A matrix blockwise copy
auto
pgrad_gemm_tile_ygrad_blockwise_copy
=
typename
Gemm0
::
template
ABlockwiseCopy
<
decltype
(
ygrad_grid_desc_o0_m_o1
)>(
ygrad_grid_desc_o0_m_o1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
Gemm0
::
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dP: B matrix blockwise copy
auto
pgrad_gemm_tile_v_blockwise_copy
=
typename
Gemm0
::
template
BBlockwiseCopy
<
decltype
(
v_grid_desc_o0_n_o1
)>(
v_grid_desc_o0_n_o1
,
make_multi_index
(
0
,
0
,
0
),
// will loop over GemmN dimension
tensor_operation
::
element_wise
::
PassThrough
{},
Gemm0
::
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dP: blockwise gemm
// we need separate blockwise gemm object because we need separate thread buffer
auto
pgrad_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
auto
pgrad_thread_buf
=
pgrad_blockwise_gemm
.
GetCThreadBuffer
();
const
auto
pgrad_gemm_tile_ygrad_block_reset_copy_step
=
make_multi_index
(
-
ygrad_grid_desc_o0_m_o1
.
GetLength
(
I0
),
0
,
0
);
const
auto
pgrad_gemm_tile_v_block_reset_copy_step
=
make_multi_index
(
-
v_grid_desc_o0_n_o1
.
GetLength
(
I0
),
NPerBlock
,
0
);
const
index_t
num_o_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
ygrad_grid_desc_o0_m_o1
.
GetLength
(
I0
)
*
ygrad_grid_desc_o0_m_o1
.
GetLength
(
I2
))
/
KPerBlock
);
//
// set up Y / dQ Gemm (type 2 rrr)
//
// Note: Y is pre-calculated in forward pass and loaded to backward pass kernel
using
Gemm1
=
Gemm1
<
decltype
(
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()),
decltype
(
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
())
>
;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
Gemm1
::
a_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
gemm1_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
Gemm1
::
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// dQ: transform input and output tensor descriptors
const
auto
k_grid_desc_n0_k_n1
=
QGradGemmTile_M_K_N
::
MakeKGridDesc_N0_K_N1
(
k_grid_desc_k0_n_k1
);
auto
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
=
QGradGemmTile_M_K_N
::
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock
(
q_grid_desc_k0_m_k1
);
// dQ: A matrix blockwise copy
auto
qgrad_gemm_tile_sgrad_blockwise_copy
=
typename
Gemm1
::
ABlockwiseCopy
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// dQ: B matrix blockwise copy
auto
qgrad_gemm_tile_k_blockwise_copy
=
typename
Gemm1
::
template
BBlockwiseCopy
<
decltype
(
k_grid_desc_n0_k_n1
)>(
k_grid_desc_n0_k_n1
,
make_multi_index
(
0
,
o_block_data_idx_on_grid
,
0
),
b1_element_op
,
Gemm1
::
b_block_desc_bk0_n_bk1
,
// there n actually is k, k is N, so name can be
// b_block_desc_bn0_k_bn1
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dQ: blockwise gemm
auto
qgrad_blockwise_gemm
=
typename
Gemm1
::
BlockwiseGemm
{
make_tuple
(
0
,
0
,
0
,
0
)};
// A_origin
auto
qgrad_thread_buf
=
qgrad_blockwise_gemm
.
GetCThreadBuffer
();
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr
auto
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
()
/
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
tm0
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I0
);
constexpr
auto
tn0
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I1
);
constexpr
auto
tm1
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I2
);
constexpr
auto
tn1
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I3
);
constexpr
auto
tm2
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I4
);
constexpr
auto
tn2
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I5
);
constexpr
auto
tn3
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I6
);
constexpr
auto
tn4
=
thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4
.
At
(
I7
);
// get acc0 thread map
constexpr
auto
m0_n_m1_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
tm0
*
tm1
,
tm2
)),
make_pass_through_transform
(
I1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
constexpr
auto
threadid_to_m0_n_m1_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
tm0
*
tm1
,
tn0
*
tn1
*
tn2
*
tn3
*
tn4
,
tm2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
threadid_to_m_n_thread_cluster_adaptor
=
chain_tensor_adaptors
(
m0_n_m1_to_m_n_adaptor
,
threadid_to_m0_n_m1_adaptor
);
// get acc0 2D thread cluster & 2D thread slice
constexpr
auto
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
n0
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
m1
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
n1
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
m2
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
n2
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
n3
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
n4
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
thread_cluster_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
tm0
*
tm1
*
tm2
,
tn0
*
tn1
*
tn2
*
tn3
*
tn4
));
constexpr
auto
thread_slice_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
*
m1
*
m2
,
n0
*
n1
*
n2
*
n3
*
n4
));
auto
blockwise_softmax
=
BlockwiseSoftmax
<
BlockSize
,
FloatGemmAcc
,
decltype
(
threadid_to_m_n_thread_cluster_adaptor
),
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_16bits
,
rp_dropout
};
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
lse_grid_desc_m
);
constexpr
auto
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
m0
,
m1
,
m2
));
auto
lse_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatLSE
>
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
.
GetElementSpaceSize
());
auto
acc0_thread_origin
=
s_blockwise_gemm
.
CalculateCThreadOriginDataIndex8D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
auto
lse_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatLSE
,
FloatLSE
,
decltype
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
),
Sequence
<
1
,
m0
,
m1
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
m2
,
1
,
false
>
{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// registerNum
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
n4
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up dV / dK Gemm (type 3 crr)
//
using
Gemm2
=
Gemm2
<
Gemm2Params_N_O_M
,
decltype
(
s_blockwise_gemm
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a2_block_space_offset
,
Gemm2
::
a_block_desc_m0_n_m1
.
GetElementSpaceSize
());
auto
gemm2_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b2_block_space_offset
,
Gemm2
::
b_block_desc_m0_o_m1
.
GetElementSpaceSize
());
// dV: transform input and output tensor descriptors
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
// dV: A matrix VGPR-to-LDS blockwise copy
auto
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
template
ABlockwiseCopy
<
tensor_operation
::
element_wise
::
Relu
>{
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
Relu
{}};
// relu(P-dropped)
// dV: B matrix global-to-LDS blockwise copy
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
typename
Gemm2
::
template
BBlockwiseCopy
<
decltype
(
ygrad_grid_desc_m0_o_m1
)>(
ygrad_grid_desc_m0_o_m1
,
make_multi_index
(
m_block_data_idx_on_grid
/
Gemm2Params_N_O_M
::
B_M1
,
o_block_data_idx_on_grid
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
Gemm2
::
b_block_desc_m0_o_m1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dV: blockwise gemm
auto
v_slash_k_grad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{};
auto
v_slash_k_grad_thread_buf
=
v_slash_k_grad_blockwise_gemm
.
GetCThreadBuffer
();
// dV: C VGPR-to-global copy
const
auto
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
I0
,
block_work_idx
[
I1
]
*
Gemm2Params_N_O_M
::
GemmORepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
vgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
tensor_operation
::
element_wise
::
Scale
>(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
KGradGemmTile_N_K_M
::
MakeQGridDesc_M0_K_M1
(
q_grid_desc_k0_m_k1
);
const
auto
kgrad_grid_desc_n_k
=
KGradGemmTile_N_K_M
::
MakeKGradGridDesc_N_K
(
k_grid_desc_k0_n_k1
);
const
auto
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
kgrad_grid_desc_n_k
);
// dK: A matrix VGPR-to-LDS blockwise copy
auto
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
template
ABlockwiseCopy
<
tensor_operation
::
element_wise
::
PassThrough
>{
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
PassThrough
{}};
// dK: B matrix global-to-LDS blockwise copy
auto
kgrad_gemm_tile_q_blockwise_copy
=
typename
Gemm2
::
template
BBlockwiseCopy
<
decltype
(
q_grid_desc_m0_k_m1
)>(
q_grid_desc_m0_k_m1
,
make_multi_index
(
m_block_data_idx_on_grid
/
Gemm2Params_N_O_M
::
B_M1
,
o_block_data_idx_on_grid
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
Gemm2
::
b_block_desc_m0_o_m1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dK: blockwise gemm
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
// dK: C VGPR-to-global copy
const
auto
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
=
Gemm2
::
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
+
make_multi_index
(
I0
,
block_work_idx
[
I1
]
*
Gemm2Params_N_O_M
::
GemmORepeat
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
kgrad_thread_copy_vgpr_to_global
=
typename
Gemm2
::
template
CBlockwiseCopy
<
decltype
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
decltype
(
s_element_op
)>(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
s_element_op
);
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr
auto
p_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
P_M0
=
p_block_lengths
[
I0
];
// repeats
constexpr
auto
P_M1
=
p_block_lengths
[
I2
];
// waves
constexpr
auto
P_M2
=
p_block_lengths
[
I4
];
// xdl
constexpr
auto
y_thread_desc_m0_m1_o0_o1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
const
auto
y_thread_cluster_idx
=
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
block_work_idx
[
I0
],
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
y_thread_data_on_block_idx
;
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
true
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_O
::
DstBufType
{};
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
),
MPerBlock
);
constexpr
auto
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
),
make_tuple
(
P_M0
*
P_M1
*
P_M2
,
P_M1
*
P_M2
,
P_M2
,
I1
));
constexpr
auto
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
=
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
;
// reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto
y_dot_ygrad_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
),
Sequence
<
1
,
m0
,
m1
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
m2
,
1
,
false
>
{
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
I0
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
.
GetElementSpaceSize
());
        //
        // calculate Y dot dY
        //

        // clear accum buffers
        y_dot_ygrad_thread_accum_buf.Clear();
        y_dot_ygrad_block_accum_buf.Clear();

        index_t oblock_idx = 0;
        do
        {
            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
                                       y_grid_buf,
                                       y_thread_desc_m0_m1_o0_o1,
                                       make_tuple(I0, I0, I0, I0),
                                       y_thread_buf);
            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
                                       ygrad_grid_buf,
                                       y_thread_desc_m0_m1_o0_o1,
                                       make_tuple(I0, I0, I0, I0),
                                       ygrad_thread_buf);

            static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
                    constexpr auto offset = y_thread_desc_m0_m1_o0_o1.CalculateOffset(
                        make_multi_index(I0, iM, I0, iO));
                    y_dot_ygrad_thread_accum_buf(iM) +=
                        y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
                });
            });

            yygrad_threadwise_copy.MoveSrcSliceWindow(
                y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(0, 0, 1, 0));

            oblock_idx++;
        } while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));

        // blockwise reduction using atomic_add
        block_sync_lds();
        static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
            const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
            y_dot_ygrad_block_accum_buf.AtomicAdd(
                idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropoutD1
        });
        block_sync_lds();

        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
        y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
            y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
            y_dot_ygrad_block_accum_buf,
            y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
            make_tuple(I0, I0, I0, I0),
            y_dot_ygrad_thread_buf);

        lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
                                           lse_grid_buf,
                                           lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
                                           make_tuple(I0, I0, I0, I0),
                                           lse_thread_buf);
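        // Background sketch for the reduction above (standard FlashAttention-style backward
        // identity, stated here for orientation): with P = softmax(S) and Y = P * V, the softmax
        // backward needs the per-row scalar
        //   D_i = sum_o dY[i, o] * Y[i, o]
        // so that dS[i, j] = P[i, j] * (dP[i, j] - D_i). Each thread accumulates its slice of the
        // dot product in VGPRs, the block combines partial sums in LDS via AtomicAdd (scaled by
        // p_dropout above to fold the dropout factor into D), and the result is redistributed
        // with the same per-row tiling as the LSE values loaded just afterwards.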
        const index_t num_gemm1_k_block_outer_loop =
            k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

        const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
        const float scalar = 1.0f / std::sqrt(K);

        // Initialize dQ
        qgrad_thread_buf.Clear();
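        // Illustrative value only (not from the source): for head_dim K = 64 the scaling factor
        // is scalar = 1 / sqrt(64) = 0.125, matching the 1/sqrt(d_k) scaling applied to S in the
        // forward pass.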
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
auto
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
(
c0_matrix_mask
.
IsTileSkippable
(
m_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
MPerBlock
,
NPerBlock
))
{
continue
;
}
// S = Q * K^T
gemm0_gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
q_grid_desc_k0_m_k1
,
Gemm0
::
a_block_desc_ak0_m_ak1
,
s_gemm_tile_q_blockwise_copy
,
q_grid_buf
,
gemm0_a_block_buf
,
Gemm0
::
a_block_slice_copy_step
,
k_grid_desc_k0_n_k1
,
Gemm0
::
b_block_desc_bk0_n_bk1
,
s_gemm_tile_k_blockwise_copy
,
k_grid_buf
,
gemm0_b_block_buf
,
Gemm0
::
b_block_slice_copy_step
,
s_blockwise_gemm
,
s_slash_p_thread_buf
,
num_k_block_main_loop
);
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
s_slash_p_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
s_slash_p_thread_buf
(
i
)
=
scalar
*
s_slash_p_thread_buf
[
i
];
}
});
}
else
{
static_for
<
0
,
s_slash_p_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
s_slash_p_thread_buf
(
i
)
=
scalar
*
s_slash_p_thread_buf
[
i
];
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
if
(
p_z_grid
)
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
}
else
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
}
block_sync_lds
();
// wait for gemm1 LDS read
SubThreadBlock
<
BlockSize
>
gemm2_a_copy_subgroup
(
s_blockwise_gemm
.
GetWaveIdx
()[
I0
],
s_blockwise_gemm
.
GetWaveIdx
()[
I1
]);
constexpr
index_t
num_gemm2_loop
=
MPerBlock
/
Gemm2Params_N_O_M
::
Sum_M
;
static_assert
(
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetNumOfAccess
()
==
num_gemm2_loop
,
""
);
// TODO: tune gemm2 pipeline
// dV = P_drop^T * dY
v_slash_k_grad_thread_buf
.
Clear
();
static_for
<
0
,
num_gemm2_loop
,
1
>
{}([
&
](
auto
gemm2_loop_idx
)
{
// gemm dV
// load VGrad Gemm B
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_buf
);
// load VGrad Gemm A
const
auto
p_slice_idx
=
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetIndexTupleOfNumber
(
gemm2_loop_idx
);
constexpr
auto
mwave_range
=
make_tuple
(
p_slice_idx
[
I2
],
p_slice_idx
[
I2
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I2
));
constexpr
auto
nwave_range
=
make_tuple
(
p_slice_idx
[
I3
],
p_slice_idx
[
I3
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I3
));
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
.
Run
(
Gemm2
::
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
p_slice_idx
[
I0
],
p_slice_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
s_slash_p_thread_buf
,
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
gemm2_a_block_buf
);
}
// ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// p slice window is moved by loop index
vgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
ygrad_grid_desc_m0_o_m1
,
Gemm2
::
b_block_slice_copy_step
);
block_sync_lds
();
// sync before write
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunWrite
(
Gemm2
::
b_block_desc_m0_o_m1
,
gemm2_b_block_buf
);
block_sync_lds
();
// sync before read
v_slash_k_grad_blockwise_gemm
.
Run
(
gemm2_a_block_buf
,
gemm2_b_block_buf
,
v_slash_k_grad_thread_buf
);
});
// end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global
.
Run
(
Gemm2
::
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
v_slash_k_grad_thread_buf
,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
vgrad_grid_buf
);
// gemm dP
block_sync_lds
();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
ygrad_grid_desc_o0_m_o1
,
Gemm0
::
a_block_desc_ak0_m_ak1
,
// reuse
pgrad_gemm_tile_ygrad_blockwise_copy
,
ygrad_grid_buf
,
gemm0_a_block_buf
,
// reuse
Gemm0
::
a_block_slice_copy_step
,
// reuse
v_grid_desc_o0_n_o1
,
Gemm0
::
b_block_desc_bk0_n_bk1
,
// reuse
pgrad_gemm_tile_v_blockwise_copy
,
v_grid_buf
,
gemm0_b_block_buf
,
// reuse
Gemm0
::
b_block_slice_copy_step
,
// reuse
pgrad_blockwise_gemm
,
pgrad_thread_buf
,
num_o_block_main_loop
);
            // dS = P * (dP - Y_dot_dY)
            auto& sgrad_thread_buf = pgrad_thread_buf;

            constexpr auto pgrad_thread_tile_iterator =
                pgrad_blockwise_gemm.MakeCThreadTileIterator();
            constexpr auto pgrad_thread_idx_to_m_n_adaptor =
                pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
            static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
                constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
                constexpr auto m =
                    pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
                // dS and P have the same thread buffer layout
                if(s_slash_p_thread_buf[i] >= 0)
                {
                    sgrad_thread_buf(i) =
                        s_slash_p_thread_buf[i] *
                        (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
                }
                else
                {
                    sgrad_thread_buf(i) =
                        s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
                }
            });
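            // Explanatory sketch only: entries kept by dropout follow the softmax-backward
            // formula dS = P * (dP - D) noted earlier, while entries flagged by the dropout step
            // (stored with flipped sign; compare the Relu applied when copying P for the dV gemm)
            // take the other branch. The exact sign convention is defined by
            // BlockwiseDropout::ApplyDropout, so treat this comment as a reading aid rather than
            // a specification.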
        // gemm dQ
        // dQ = scalar * dS * K
        {
            // TODO: explore using dynamic buffer for a1 thread buffer
            // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
            // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
            // the A1 source buffer is static buffer holding the output of first GEMM and
            // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
            // explicitly in Run() below.

            // preload data into LDS
            qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
            qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
                                                                Gemm1::b_block_slice_copy_step);

            block_sync_lds(); // wait for previous LDS read

            qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
                                                      gemm1_b_block_buf);

            // main body
            if constexpr(num_gemm1_k_block_inner_loop > 1)
            {
                static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                    qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
                                                             Gemm1::a_block_slice_copy_step * i,
                                                             sgrad_thread_buf,
                                                             Gemm1::a_thread_desc_k0_m_k1,
                                                             make_tuple(I0, I0, I0),
                                                             gemm1_a_thread_buf);

                    qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);

                    block_sync_lds();

                    qgrad_blockwise_gemm.Run(
                        gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);

                    block_sync_lds();

                    qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                        k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);
                    qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
                                                              gemm1_b_block_buf);
                });
            }

            // tail
            {
                qgrad_gemm_tile_sgrad_blockwise_copy.Run(
                    Gemm1::a_src_thread_desc_k0_m_k1,
                    Gemm1::a_block_slice_copy_step *
                        Number<num_gemm1_k_block_inner_loop - 1>{},
                    sgrad_thread_buf,
                    Gemm1::a_thread_desc_k0_m_k1,
                    make_tuple(I0, I0, I0),
                    gemm1_a_thread_buf);

                block_sync_lds();

                qgrad_blockwise_gemm.Run(gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
            }
        }
        // end gemm dQ
        // dK = scalar * dS^T * dQ
        v_slash_k_grad_thread_buf.Clear();

        static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) {
            // gemm dK
            // load KGrad Gemm B
            kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);

            // load KGrad Gemm A
            const auto sgrad_slice_idx =
                Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
            constexpr auto mwave_range =
                make_tuple(sgrad_slice_idx[I2],
                           sgrad_slice_idx[I2] +
                               Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
            constexpr auto nwave_range =
                make_tuple(sgrad_slice_idx[I3],
                           sgrad_slice_idx[I3] +
                               Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));

            if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
            {
                kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
                    Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    make_tuple(
                        sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
                    sgrad_thread_buf,
                    Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    gemm2_a_block_buf);
            }

            // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
            // sgrad slice window is moved by loop index
            kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
                                                                Gemm2::b_block_slice_copy_step);

            block_sync_lds(); // sync before write

            kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
                                                      gemm2_b_block_buf);

            block_sync_lds(); // sync before read

            v_slash_k_grad_blockwise_gemm.Run(
                gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
        });
        // end gemm dK

        // atomic_add dK
        kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                             v_slash_k_grad_thread_buf,
                                             kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                             kgrad_grid_buf);

        // move slice window
        s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
            q_grid_desc_k0_m_k1, s_gemm_tile_a_block_reset_copy_step); // rewind K
        s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
            k_grid_desc_k0_n_k1, s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
        vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
            ygrad_grid_desc_m0_o_m1, Gemm2::b_block_reset_copy_step); // rewind M
        vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
            vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
        pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
            ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
        pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
            v_grid_desc_o0_n_o1, pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
        kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
            q_grid_desc_m0_k_m1, Gemm2::b_block_reset_copy_step); // rewind M
        kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
        z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));

    } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
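Each pass through the do/while body above therefore produces the gradient tiles for one K/V block: dV = P^T * dY and dK = scale * dS^T * Q are accumulated into global memory with atomics, dP = dY * V^T and dS feed the dQ GEMM, and the slice windows are rewound for the next block. A dense host-side reference of that per-tile dataflow (naive loops, illustrative names, a single scale factor, masking and dropout omitted) could look like:

#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>; // [rows][cols], illustrative only

// Hypothetical dense reference for one tile of the attention backward pass,
// mirroring the per-block GEMMs above: dV = P^T * dY, dP = dY * V^T,
// dS = P * (dP - rowsum(Y * dY)), dQ += scale * dS * K, dK += scale * dS^T * Q.
// Output matrices dQ, dK, dV are assumed pre-sized and are accumulated into.
void attention_backward_tile(const Mat& Q, const Mat& K, const Mat& V, const Mat& P,
                             const Mat& dY, const std::vector<float>& y_dot_dy,
                             float scale, Mat& dQ, Mat& dK, Mat& dV)
{
    const std::size_t M = P.size();    // query rows in this tile
    const std::size_t N = K.size();    // key rows in this tile
    const std::size_t D = Q[0].size(); // head dimension of Q/K
    const std::size_t O = V[0].size(); // head dimension of V/dY

    Mat dP(M, std::vector<float>(N, 0.f));
    Mat dS(M, std::vector<float>(N, 0.f));

    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t j = 0; j < N; ++j)
        {
            for(std::size_t o = 0; o < O; ++o)
            {
                dV[j][o] += P[i][j] * dY[i][o]; // dV = P^T * dY
                dP[i][j] += dY[i][o] * V[j][o]; // dP = dY * V^T
            }
            dS[i][j] = P[i][j] * (dP[i][j] - y_dot_dy[i]);
        }

    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t j = 0; j < N; ++j)
            for(std::size_t d = 0; d < D; ++d)
            {
                dQ[i][d] += scale * dS[i][j] * K[j][d]; // dQ = scale * dS * K
                dK[j][d] += scale * dS[i][j] * Q[i][d]; // dK = scale * dS^T * Q
            }
}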
        // shuffle dQ and write
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                qgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
                qgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
            constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
            constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2)),                                    // M2 = MPerXdl
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2,                                      // N2 * N3 * N4 = NPerXdl
                        N3,
                        N4))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM C matrix starting index
            const auto c_thread_mtx_on_block =
                qgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                                                   SElementwiseOperation,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            I1,
                                                            N2,
                                                            I1,
                                                            N4>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I2],
                                     n_thread_data_on_block_idx[I3],
                                     n_thread_data_on_block_idx[I4]),
                    s_element_op};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation
                CGlobalMemoryDataOperation, // DstInMemOp
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder
                FloatCShuffle,        // SrcData
                DataType,             // DstData
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
                Sequence<0, 1, 2, 3>,                           // DimAccessOrder
                3,                                              // VectorDim
                CShuffleBlockTransferScalarPerVector_NPerBlock, // ScalarPerVector
                true,  // ThreadTransferSrcResetCoordinateAfterRun
                false> // ThreadTransferDstResetCoordinateAfterRun
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           1,
                                           N2,
                                           1,
                                           N4>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, Gemm1NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              qgrad_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
                    qgrad_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        qgrad_grid_desc_mblock_mperblock_kblock_kperblock, c_global_step);
                }
            });
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
@@ -35,6 +35,7 @@ template <typename FloatAB,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDesc_M_N,
          typename ZGridDesc_M_N,
          typename LSEGridDesc_M,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto WaveSize = 64;

    // K1 should be Number<...>
    // Gemm0
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    // C desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
    ////=> for z use
    {
        const auto M = z_grid_desc_m_n.GetLength(I0);
        const auto N = z_grid_desc_m_n.GetLength(I1);

        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
        constexpr auto N3   = mfma.num_groups_per_blk;
        constexpr auto N4   = mfma.num_input_blks;
        constexpr auto N5   = mfma.group_size;

        return transform_tensor_descriptor(
            z_grid_desc_m_n,
            make_tuple(make_unmerge_transform(
                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
                       make_unmerge_transform(
                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
    }

    __host__ __device__ static constexpr auto
    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
    ////=> for z use
    {
        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
        constexpr auto N3   = mfma.num_groups_per_blk;
        constexpr auto N4   = mfma.num_input_blks;
        constexpr auto N5   = mfma.group_size;

        return transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(M, N)),
            make_tuple(make_unmerge_transform(
                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
                       make_unmerge_transform(
                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
    }

    __device__ static auto GetGemm0WaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
    {
        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
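GetGemm0WaveIdx() and GetGemm0WaveMNIdx() added above use merge-transform adaptors to split the flat thread id into (m-wave, n-wave, lane) coordinates and the lane into its per-XDL coordinates. The same decomposition written as plain integer arithmetic, purely as an illustrative cross-check and not part of this header:

// Hypothetical scalar equivalent of the two adaptors above, assuming a flat
// thread id laid out as (Gemm0MWaves, Gemm0NWaves, WaveSize) with WaveSize = 64.
struct WaveCoords
{
    int m_wave;   // wave row within the block    (GetGemm0WaveIdx()[0])
    int n_wave;   // wave column within the block (GetGemm0WaveIdx()[1])
    int lane;     // lane within the wave         (GetGemm0WaveIdx()[2])
    int group;    // lane / MPerXdl               (GetGemm0WaveMNIdx()[0])
    int m_in_xdl; // lane % MPerXdl               (GetGemm0WaveMNIdx()[1])
};

inline WaveCoords decompose_gemm0_thread_id(int thread_id,
                                            int gemm0_n_waves,
                                            int wave_size,
                                            int m_per_xdl)
{
    WaveCoords c{};
    c.lane     = thread_id % wave_size;
    c.n_wave   = (thread_id / wave_size) % gemm0_n_waves;
    c.m_wave   = thread_id / (wave_size * gemm0_n_waves);
    c.group    = c.lane / m_per_xdl;
    c.m_in_xdl = c.lane % m_per_xdl;
    return c;
}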
@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                               const FloatAB* __restrict__ p_b_grid,
                               const FloatAB* __restrict__ p_b1_grid,
                               FloatC* __restrict__ p_c_grid,
                               unsigned short* __restrict__ p_z_grid,
                               FloatLSE* __restrict__ p_lse_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
            c_grid_desc_mblock_mperblock_nblock_nperblock,
        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
        const LSEGridDesc_M& lse_grid_desc_m,
        const Block2CTileMap& block_2_ctile_map,
        const C0MatrixMask& c0_matrix_mask,
@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;

        ///////////////////=>z for dropout
        //
        // z vgpr copy to global
        //
        // z matrix threadwise desc
        constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
                                                           I1,   // NBlockID
                                                           m0,   // MRepeat
                                                           n0,   // NRepeat
                                                           m1,   // MWaveId
                                                           n1,   // NWaveId
                                                           m2,   // MPerXdl
                                                           n2,   // NGroupNum
                                                           n3,   // NInputNum
                                                           n4)); // registerNum

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     unsigned short,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
                     true>
            z_tenor_buffer;
        z_tenor_buffer.Clear();

        // z matrix global desc
        /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
            MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/

        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());

        const auto wave_id     = GetGemm0WaveIdx();
        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
            ushort,
            ushort,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
            tensor_operation::element_wise::PassThrough,
            Sequence<I1, // MBlockId
                     I1, // NBlockID
                     m0, // MRepeat
                     n0, // NRepeat
                     m1, // MWaveId
                     n1, // NWaveId
                     m2, // MPerXdl
                     n2, // NGroupNum
                     n3, // NInputNum
                     n4>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            9,  // DstVectorDim
            n4, // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1, // DstScalarStrideInVector
            true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                  make_multi_index(block_work_idx[I0], // MBlockId
                                   0,                  // NBlockId
                                   0,                  // mrepeat
                                   0,                  // nrepeat
                                   wave_id[I0],        // MWaveId
                                   wave_id[I1],        // NWaveId
                                   wave_m_n_id[I1],    // MPerXdl
                                   0,                  // group
                                   wave_m_n_id[I0],    // NInputIndex
                                   0),
                  tensor_operation::element_wise::PassThrough{}};
        ///////////////////=>z for dropout

        do
        {
            auto n_block_data_idx_on_grid =
@@ -876,8 +1017,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                if constexpr(IsDropout) // dropout
                {
                    // save z to global
                    if(p_z_grid)
                    {
                        // P_dropped
                        blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
                                                                decltype(z_tenor_buffer),
                                                                true>(
                            acc_thread_buf, ph, z_tenor_buffer);

                        z_thread_copy_vgpr_to_global.Run(
                            z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                            make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                            z_tenor_buffer,
                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                            z_grid_buf);
                    }
                    else
                    {
                        // P_dropped
                        blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
                            acc_thread_buf, ph);
                    }
                }
                // if constexpr(IsDropout) // dropout
                // {
                //     blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
                // }

                // TODO: may convert to log domain
                running_max_new = mathext::max(max, running_max);
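With the hunk above, the forward kernel no longer just drops elements in place: when p_z_grid is non-null it also writes the per-element random values z out to global memory so the backward pass can replay the identical dropout mask. A minimal host-side model of that contract (illustrative names, a plain array of draws standing in for Philox output, and an assumed 16-bit keep threshold):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of "apply dropout and save z": keep an element when
// its random draw is below the keep threshold, rescale kept values by
// 1 / (1 - p_drop), and record the raw draw so a later backward pass can
// reconstruct exactly the same mask.
void apply_dropout_and_save_z(std::vector<float>& acc,
                              std::vector<uint16_t>& z_out,
                              const std::vector<uint16_t>& random_draws, // stand-in for Philox
                              float p_drop)
{
    const auto keep_threshold = static_cast<uint16_t>((1.f - p_drop) * 65535.f);
    const float rescale       = 1.f / (1.f - p_drop);

    for(std::size_t i = 0; i < acc.size(); ++i)
    {
        z_out[i] = random_draws[i]; // saved for the backward pass
        acc[i]   = (random_draws[i] <= keep_threshold) ? acc[i] * rescale : 0.f;
    }
}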
include/ck/utility/data_type.hpp
@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
    return uint16_t(u.int32 >> 16);
}

// convert fp16 to bf16
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {static_cast<float>(x)};

    return uint16_t(u.int32 >> 16);
}

template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
{
    float y0{0}, y1{0};
    bhalf2_t y{0};

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 \n \
            "
                 : "=v"(y0)
                 : "v"(x));

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1 \n \
            "
                 : "=v"(y1)
                 : "v"(x));

    asm volatile("\n \
            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
            "
                 : "=v"(y)
                 : "v"(y0), "v"(y1));

    return y;
}

template <typename T>
struct NumericLimits
{
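The new half_t to bhalf_t overload converts through fp32 and keeps the upper 16 bits of the bit pattern, i.e. a truncating fp32 to bf16 conversion without round-to-nearest-even; the vector overload does the same for two lanes via v_cvt_f32_f16 and v_pack_b32_f16. A portable standalone sketch of the scalar idea (illustrative helper name, memcpy instead of a union):

#include <cstdint>
#include <cstring>

// Hypothetical portable sketch of the same conversion: bf16 shares the sign and
// exponent layout of fp32, so keeping the high 16 bits of the fp32 bit pattern
// is a valid truncating conversion.
inline uint16_t float_to_bf16_truncate(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits)); // bit-cast without union type punning
    return static_cast<uint16_t>(bits >> 16);
}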
include/ck/utility/philox_rand.hpp
@@ -109,12 +109,9 @@ class philox
     __device__ uint2 u32_high_low_multi(const unsigned int a, const unsigned int b)
     {
         uint2* res;
-        uint2 tmp_res;
-        asm("v_mul_hi_u32 %0, %2, %3 \n\t"
-            "v_mul_lo_u32 %1, %2, %3 \n\t"
-            : "=v"(tmp_res.x), "=v"(tmp_res.y)
-            : "v"(a), "v"(b));
-        res = &tmp_res;
+        unsigned long long tmp;
+        tmp = static_cast<unsigned long long>(a) * b;
+        res = reinterpret_cast<uint2*>(&tmp);
         return *res;
     }
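The rewritten u32_high_low_multi computes the full 32x32-bit product with one widening multiply and reinterprets the 64-bit result as two 32-bit halves, instead of issuing v_mul_hi_u32/v_mul_lo_u32 through inline asm. A standalone sketch of the same trick (hypothetical helper; shifts instead of pointer reinterpretation, so the lo/hi split does not depend on how uint2 overlays the 64-bit value):

#include <cstdint>

struct U32Pair
{
    uint32_t lo; // low  32 bits of a * b
    uint32_t hi; // high 32 bits of a * b
};

// Hypothetical portable equivalent of the rewritten helper: widen one operand,
// multiply once, then split the 64-bit product into its two 32-bit halves.
inline U32Pair mul_u32_high_low(uint32_t a, uint32_t b)
{
    const uint64_t prod = static_cast<uint64_t>(a) * b;
    return {static_cast<uint32_t>(prod), static_cast<uint32_t>(prod >> 32)};
}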