gaoqiong / composable_kernel

Commit c0c52268 (Unverified)
Authored Sep 22, 2023 by Dan Yao; committed by GitHub, Sep 22, 2023

Merge pull request #905 from ROCmSoftwarePlatform/mha-train-develop-grad-bias

    flash attention output bias grad

Parents: f04ec574, c88d1173

Showing 20 changed files with 968 additions and 450 deletions (+968 -450)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp               +8    -4
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp               +8    -4
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp                  +10   -6
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp               +4    -0
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp               +5    -1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp                  +4    -0
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp                         +58   -32
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp                         +37   -6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp   +33   -35
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp   +33   -32
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp         +33   -35
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp         +38   -33
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp   +34   -10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp   +35   -10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp         +35   -10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp         +35   -11
include/ck/tensor_operation/gpu/device/matrix_padder.hpp                                             +9    -0
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp    +185  -76
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp    +178  -69
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp          +186  -76
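
For orientation: these kernels implement the FlashAttention backward pass, and this commit extends it with a gradient for the additive attention bias (the "acc0"/D0 tensor). The recurrences quoted in the device-op header comments, together with the host reference added in the bias example (which copies the score gradient straight into d0grad), can be summarized as below. This is a reading of the change, not part of the commit itself, and the forward line assumes the bias is added to the pre-softmax scores, as the acc0_bias naming suggests:

$$S = \alpha\,QK^{T} + D_0, \qquad P = \mathrm{softmax}(S), \qquad Y = PV$$

$$dS_{ij} = P_{ij}\,\bigl(dP_{ij} - dY_i \cdot Y_i\bigr), \qquad dQ = \alpha\,dS\,K, \qquad dK = \alpha\,dS^{T}Q, \qquad dD_0 = dS$$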
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -513,8 +513,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -558,8 +560,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp

@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

@@ -597,8 +597,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // p_acc0_biases;
-        {},      // p_acc1_biases;
+        nullptr, // p_acc0_biases;
+        nullptr, // p_acc1_biases;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -686,8 +688,8 @@ int run(int argc, char* argv[])
         static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
         static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -743,8 +745,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {},      // std::array<void*, 1> p_acc0_biases;
-        {},      // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

@@ -604,6 +604,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},

@@ -650,6 +652,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp

@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 32  // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>

@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},

@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp

@@ -728,6 +728,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs_bwd,
         QKVElementOp{},
         QKVElementOp{},

@@ -815,6 +817,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs_bwd,
         QKVElementOp{},
         QKVElementOp{},
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64  // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>

@@ -57,6 +57,7 @@ using BF16 = ck::bhalf_t;
 using F32         = float;
 using U16         = unsigned short;
 using INT32       = int32_t;
+using U8          = uint8_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;

@@ -374,8 +375,8 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
-    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
-    std::vector<ck::index_t> d_gs_ms_ns_strides =
+    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> d0_gs_ms_ns_strides =
         input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]

@@ -396,7 +397,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
     Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);

@@ -405,7 +406,7 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
-    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
+    std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
     std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;

@@ -420,36 +421,35 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
-        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
         break;
     case 2:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
         break;
     case 3:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     case 4:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     case 5:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:

@@ -457,7 +457,7 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -471,7 +471,7 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -485,7 +485,7 @@ int run(int argc, char* argv[])
     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
+    DeviceMem d0_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());

@@ -494,12 +494,14 @@ int run(int argc, char* argv[])
     DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());

     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+    d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

@@ -516,8 +518,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),  // p_acc0_bias;
+        static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                         // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -529,10 +533,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
-        d_gs_ms_ns_lengths,  // acc0_bias_gs_ms_ns_lengths
-        d_gs_ms_ns_strides,  // acc0_bias_gs_ms_ns_strides
+        d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+        d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {},                  // acc1_bias_gs_ms_os_lengths,
         {},                  // acc1_bias_gs_ms_os_strides,
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},

@@ -561,8 +565,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),  // p_acc0_bias;
+        static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                         // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -574,10 +580,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
-        d_gs_ms_ns_lengths,  // acc0_bias_gs_ms_ns_lengths
-        d_gs_ms_ns_strides,  // acc0_bias_gs_ms_ns_strides
+        d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+        d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {},                  // acc1_bias_gs_ms_os_lengths,
         {},                  // acc1_bias_gs_ms_os_strides,
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},

@@ -599,7 +605,7 @@ int run(int argc, char* argv[])
         (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
          sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
          sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
+         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
             BatchCount +
         sizeof(LSEDataType) * M * BatchCount;

@@ -618,7 +624,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
     Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-    Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
+    Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
     Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
     Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});

@@ -640,13 +646,13 @@ int run(int argc, char* argv[])
     v_gs_os_ns.ForEach([&](auto& self, auto idx) {
         v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
     });
-    d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-        d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+    d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+        d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
     // run fwd again for y, cause z_g_m_n update
     run_attention_fwd_host(q_g_m_k,
                            k_g_n_k,
-                           d_g_m_n,
+                           d0_g_m_n,
                            v_g_n_o,
                            alpha,
                            s_g_m_n,

@@ -783,14 +789,19 @@ int run(int argc, char* argv[])
         Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
         Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);

         Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
         Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);

         qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
         kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
         vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
+        d0grad_device_buf.FromDevice(d0grad_gs_ms_ns_device_result.mData.data());

         // permute
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {

@@ -818,6 +829,15 @@ int run(int argc, char* argv[])
             self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
         });
+        d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0];
+            const size_t& g1 = idx[1];
+            const size_t g   = g0 * G1 + g1;
+
+            self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
+        });

         std::cout << "Checking qgrad:\n";
         pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
                                      qgrad_gs_ms_ks_host_result.mData,

@@ -836,6 +856,12 @@ int run(int argc, char* argv[])
                                      "error",
                                      1e-2,
                                      1e-2);
+        std::cout << "Checking d0grad:\n";
+        pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
+                                     d0grad_gs_ms_ns_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
     }

     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
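
The only genuinely new piece of the host verification above is the d0grad reference: it copies the score gradient sgrad straight into the bias-gradient tensor, only unfolding the (g0, g1) batch index. Below is a stand-alone restatement of that loop using plain std::vector instead of the library's Tensor class; the row-major [G0, G1, M, N] flattening is an assumption made here purely for illustration (the example's Tensor version handles arbitrary strides).

#include <cstddef>
#include <vector>

// Host reference for the bias gradient, mirroring the ForEach lambda added above:
// dD0(g0, g1, m, n) = dS(g, m, n) with g = g0 * G1 + g1, i.e. the bias gradient is
// the softmax-input gradient regrouped from [G, M, N] to [G0, G1, M, N].
void reference_d0grad(const std::vector<float>& sgrad_g_m_n, // flattened [G0*G1, M, N]
                      std::vector<float>& d0grad_gs_ms_ns,   // flattened [G0, G1, M, N]
                      std::size_t G0, std::size_t G1, std::size_t M, std::size_t N)
{
    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t g1 = 0; g1 < G1; ++g1)
        {
            const std::size_t g = g0 * G1 + g1; // same batch folding as the example
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t n = 0; n < N; ++n)
                    d0grad_gs_ms_ns[((g0 * G1 + g1) * M + m) * N + n] =
                        sgrad_g_m_n[(g * M + m) * N + n];
        }
}

With this particular contiguous layout the copy is element-wise; the point of the sketch is only that no extra arithmetic is applied, dD0 is dS as-is.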
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp

@@ -333,6 +333,7 @@ int run(int argc, char* argv[])
    std::vector<const void*> p_lse;
    std::vector<void*> p_qgrad;
    std::vector<void*> p_kgrad;
+   std::vector<void*> p_d0grad;
    std::vector<void*> p_vgrad;
    std::vector<const void*> p_ygrad;

@@ -356,6 +357,7 @@ int run(int argc, char* argv[])
    std::vector<Tensor<LSEDataType>> lse_tensors;
    std::vector<Tensor<OutputDataType>> qgrad_tensors;
    std::vector<Tensor<OutputDataType>> kgrad_tensors;
+   std::vector<Tensor<Acc0BiasDataType>> d0grad_tensors;
    std::vector<Tensor<OutputDataType>> vgrad_tensors;
    std::vector<Tensor<InputDataType>> ygrad_tensors;

@@ -369,6 +371,7 @@ int run(int argc, char* argv[])
    std::vector<DeviceMemPtr> qgrad_tensors_device;
    std::vector<DeviceMemPtr> ygrad_tensors_device;
    std::vector<DeviceMemPtr> kgrad_tensors_device;
+   std::vector<DeviceMemPtr> d0grad_tensors_device;
    std::vector<DeviceMemPtr> vgrad_tensors_device;

    std::size_t group_count = 10;
    std::size_t flop = 0, num_byte = 0;

@@ -445,12 +448,13 @@ int run(int argc, char* argv[])
        int BatchCount = G0 * G1;
        flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
-       num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
-                    sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
-                    sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-                    sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
-                       BatchCount +
-                   sizeof(LSEDataType) * M * BatchCount;
+       num_byte +=
+           (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+            sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+            sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+            sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
+               BatchCount +
+           sizeof(LSEDataType) * M * BatchCount;

        Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
        Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);

@@ -600,6 +604,8 @@ int run(int argc, char* argv[])
        qgrad_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
        kgrad_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
+       d0grad_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
        vgrad_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
        ygrad_tensors_device.emplace_back(

@@ -619,6 +625,7 @@ int run(int argc, char* argv[])
        p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
        p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
        p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
+       p_d0grad.push_back(d0grad_tensors_device.back()->GetDeviceBuffer());
        p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
        p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
        p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());

@@ -636,6 +643,8 @@ int run(int argc, char* argv[])
        p_vgrad,
        p_d0,
        {},
+       p_d0grad,
+       {},
        problem_descs,
        QKVElementOp{},
        QKVElementOp{},

@@ -682,6 +691,8 @@ int run(int argc, char* argv[])
        p_vgrad,
        p_d0,
        {},
+       p_d0grad,
+       {},
        problem_descs,
        QKVElementOp{},
        QKVElementOp{},

@@ -732,6 +743,7 @@ int run(int argc, char* argv[])
        lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
        qgrad_tensors_device[i]->SetZero();
        kgrad_tensors_device[i]->SetZero();
+       d0grad_tensors_device[i]->SetZero();
        vgrad_tensors_device[i]->SetZero();
    }

@@ -804,6 +816,8 @@ int run(int argc, char* argv[])
                                                              q_tensors[i].GetStrides());
            Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
                                                              k_tensors[i].GetStrides());
+           Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
+                                                                d0_tensors[i].GetStrides());
            Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
                                                              v_tensors[i].GetStrides());

@@ -811,11 +825,14 @@ int run(int argc, char* argv[])
                                                                q_tensors[i].GetStrides());
            Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
                                                                k_tensors[i].GetStrides());
+           Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
+                                                                  d0_tensors[i].GetStrides());
            Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
                                                                v_tensors[i].GetStrides());

            qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
            kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+           d0grad_tensors_device[i]->FromDevice(d0grad_gs_ms_ns_device_result.data());
            vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());

            // permute
            qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {

@@ -834,6 +851,14 @@ int run(int argc, char* argv[])
                self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
            });
+           d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
+               const size_t& g0 = idx[0];
+               const size_t& g1 = idx[1];
+               const size_t g   = g0 * G1 + g1;
+
+               self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
+           });
            vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
                const size_t& g0 = idx[0];
                const size_t& g1 = idx[1];

@@ -861,6 +886,12 @@ int run(int argc, char* argv[])
                                         "error",
                                         1e-2,
                                         1e-2);
+           std::cout << "Checking d0grad:\n";
+           pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
+                                        d0grad_gs_ms_ns_host_result.mData,
+                                        "error",
+                                        1e-2,
+                                        1e-2);
        }
    }
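
Both bias examples also bump the traffic estimate used for the bandwidth report: the M x N bias tensor is now counted twice, once as the D0 input and once as the dD0 output. In the examples' notation (a reading of the updated formula, not an authoritative derivation):

$$\text{num\_byte} = \Bigl[(MK + KN + NO + 2MO)\,\sigma_{in} + (MK + KN + NO)\,\sigma_{out} + 2MN\,\sigma_{bias}\Bigr]\cdot G \;+\; M\,\sigma_{lse}\cdot G$$

where sigma_in = sizeof(InputDataType), sigma_out = sizeof(OutputDataType), sigma_bias = sizeof(Acc0BiasDataType), sigma_lse = sizeof(LSEDataType), and G = BatchCount; the previous code had only MN * sigma_bias (bias read, no gradient written).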
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
c0c52268
...
@@ -123,6 +123,7 @@ __global__ void
...
@@ -123,6 +123,7 @@ __global__ void
const
InputDataType
*
__restrict__
p_ygrad_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
OutputDataType
*
__restrict__
p_qgrad_grid
,
OutputDataType
*
__restrict__
p_qgrad_grid
,
OutputDataType
*
__restrict__
p_kgrad_grid
,
OutputDataType
*
__restrict__
p_kgrad_grid
,
D0DataType
*
__restrict__
p_d0grad_grid
,
OutputDataType
*
__restrict__
p_vgrad_grid
,
OutputDataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
...
@@ -176,11 +177,19 @@ __global__ void
...
@@ -176,11 +177,19 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
D0DataType
*
tmp_p_d0grad_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
if
(
p_d0_grid
!=
nullptr
)
{
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
}
if
(
p_d0grad_grid
!=
nullptr
)
{
tmp_p_d0grad_grid
=
p_d0grad_grid
+
d0_batch_offset
;
}
}
}
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
...
@@ -197,6 +206,7 @@ __global__ void
...
@@ -197,6 +206,7 @@ __global__ void
p_ygrad_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
...
@@ -233,6 +243,7 @@ __global__ void
...
@@ -233,6 +243,7 @@ __global__ void
p_ygrad_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
...
@@ -266,6 +277,7 @@ __global__ void
...
@@ -266,6 +277,7 @@ __global__ void
ignore
=
p_ygrad_grid
;
ignore
=
p_ygrad_grid
;
ignore
=
p_qgrad_grid
;
ignore
=
p_qgrad_grid
;
ignore
=
p_kgrad_grid
;
ignore
=
p_kgrad_grid
;
ignore
=
p_d0grad_grid
;
ignore
=
p_vgrad_grid
;
ignore
=
p_vgrad_grid
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
...
@@ -579,32 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -579,32 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
return
Transform
::
MakeC0GridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
@@ -635,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -635,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
return
Transform
::
MakeC
0
GridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
...
@@ -665,7 +652,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -665,7 +652,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
@@ -673,7 +660,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -673,7 +660,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
...
@@ -858,6 +845,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
             OutputDataType* p_vgrad_grid,
             const D0DataType* p_acc0_bias,
             const D1DataType* p_acc1_bias,
+            D0DataType* p_d0grad_grid,
+            D1DataType* p_d1grad_grid,
             const std::vector<index_t>& a_gs_ms_ks_lengths,
             const std::vector<index_t>& a_gs_ms_ks_strides,
             const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -894,6 +883,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
...
@@ -921,7 +911,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              z_grid_desc_g_m_n_{
-                 Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
+                 Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              d_block_2_ctile_map_{
                  GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
...
@@ -948,10 +938,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_drop_{p_drop}
        {
            // TODO: implement bias addition
-           ignore = p_acc0_bias;
+           ignore = p_d1grad_grid;
            ignore = p_acc1_bias;
-           ignore = acc0_bias_gs_ms_ns_lengths;
-           ignore = acc0_bias_gs_ms_ns_strides;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
...
@@ -962,7 +950,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                    GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
-               d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
+               d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...
@@ -1030,6 +1018,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -1191,6 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
+                   arg.p_d0grad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
...
@@ -1342,6 +1332,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                            OutputDataType* p_vgrad_grid,
                            const D0DataType* p_acc0_bias,
                            const D1DataType* p_acc1_bias,
+                           D0DataType* p_d0grad_grid,
+                           D1DataType* p_d1grad_grid,
                            const std::vector<index_t>& a_gs_ms_ks_lengths,
                            const std::vector<index_t>& a_gs_ms_ks_strides,
                            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1380,6 +1372,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
...
@@ -1422,6 +1416,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        void* p_vgrad_grid,
                        const D0DataType* p_acc0_bias,
                        const D1DataType* p_acc1_bias,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1461,6 +1457,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
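Across the batched variants above and below, the Argument change is the same: the attention-bias gradient output becomes a stored member that is later forwarded to the gridwise kernel, while the gemm1-bias gradient remains unimplemented and is therefore ignored. A minimal stand-in, using plain float pointers instead of the templated data types (illustrative only, not the commit's class):

// Sketch of the repeated Argument change; may-be-nullptr semantics follow the diff above.
struct ArgumentSketch
{
    float* p_qgrad_grid_;
    float* p_kgrad_grid_;
    float* p_vgrad_grid_;
    float* p_d0grad_grid_; // new: attention-bias gradient output, may be nullptr

    ArgumentSketch(float* p_qgrad, float* p_kgrad, float* p_vgrad, float* p_d0grad)
        : p_qgrad_grid_{p_qgrad},
          p_kgrad_grid_{p_kgrad},
          p_vgrad_grid_{p_vgrad},
          p_d0grad_grid_{p_d0grad}
    {
    }
};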
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @ c0c52268
...
@@ -123,6 +123,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
...
@@ -176,11 +177,19 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-       tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       if(p_d0_grid != nullptr)
+       {
+           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       }
+       if(p_d0grad_grid != nullptr)
+       {
+           tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+       }
    }
    if constexpr(Deterministic)
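The guarded offsets above let one kernel serve both the bias-only and bias-plus-bias-gradient cases: each per-batch D0 pointer is advanced only when the caller actually provided a buffer. A minimal host-side analogue of that pattern, with simplified types and illustrative numbers that are not from the commit:

#include <cstddef>

// Host-side sketch only (not kernel code): an absent dD0 output simply stays nullptr.
template <typename T>
T* offset_into_batch(T* base, std::ptrdiff_t batch_offset)
{
    return base != nullptr ? base + batch_offset : nullptr;
}

int main()
{
    float d0_bias[64]   = {};      // stand-in for p_d0_grid (attention-bias values)
    float* d0_bias_grad = nullptr; // caller did not request the bias gradient

    const std::ptrdiff_t d0_batch_offset = 2 * 16; // e.g. g_idx * per-batch stride

    const float* tmp_p_d0_grid = offset_into_batch<const float>(d0_bias, d0_batch_offset);
    float* tmp_p_d0grad_grid   = offset_into_batch(d0_bias_grad, d0_batch_offset);

    return (tmp_p_d0_grid != nullptr && tmp_p_d0grad_grid == nullptr) ? 0 : 1;
}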
...
@@ -198,6 +207,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -234,6 +244,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -267,6 +278,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
...
@@ -587,39 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
                                         const std::vector<index_t>& d_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+       return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    }
    // Z in Gemm0 C position
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-   }
-   //
-   // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-   //
-   //
-   // dQ = alpha * dS * K
-   //
-   // QGrad in Gemm C position
-   static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
-                                           const std::vector<index_t>& q_gs_ms_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   }
-   //
-   // dK = alpha * dS^T * Q
-   //
-   // KGrad in Gemm C position
-   static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
-                                           const std::vector<index_t>& k_gs_ns_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+       return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }
    static auto MakeLSEGridDescriptor_M(index_t MRaw)
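The comment block removed above (it also disappears from the V1/V2 variants later in this commit) states the backward recurrences these kernels compute. Restated in LaTeX, with the last identity a hedged note on why a bias gradient can be emitted at all: the D0 bias is added elementwise to the pre-softmax logits, so its gradient is simply dS, reduced over any broadcast dimensions; that reduction detail is an assumption about the usual bias layout, not something stated in this diff.

\[
\mathrm{d}S_{ij} = P_{ij}\,\bigl(\mathrm{d}P_{ij} - \mathrm{d}Y_i \cdot Y_i\bigr), \qquad
\mathrm{d}Q = \alpha\, \mathrm{d}S\, K, \qquad
\mathrm{d}K = \alpha\, \mathrm{d}S^{\top} Q, \qquad
\mathrm{d}D_0 = \mathrm{d}S \ \text{(summed over broadcast dims)}.
\]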
...
@@ -674,7 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
-   using D0GridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using D0GridDesc_G_M_N     = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using YGridDesc_M_O        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
...
@@ -682,7 +669,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    using BGridDesc_G_N_K  = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-   using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using DYGridDesc_M_O   = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
    using DGridDesc_M      = decltype(MakeDGridDescriptor_M(1));
...
@@ -874,6 +861,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            OutputDataType* p_vgrad_grid,
            const D0DataType* p_acc0_bias,
            const D1DataType* p_acc1_bias,
+           D0DataType* p_d0grad_grid,
+           D1DataType* p_d1grad_grid,
            const std::vector<index_t>& a_gs_ms_ks_lengths,
            const std::vector<index_t>& a_gs_ms_ks_strides,
            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -910,6 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
...
@@ -936,7 +926,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              z_grid_desc_g_m_n_{
-                 Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
+                 Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              d_block_2_ctile_map_{
                  GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
...
@@ -964,6 +954,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
+           ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
...
@@ -974,7 +965,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                    GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
-               d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
+               d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...
@@ -1042,6 +1033,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -1207,6 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
+                   arg.p_d0grad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
...
@@ -1374,6 +1367,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                            OutputDataType* p_vgrad_grid,
                            const D0DataType* p_acc0_bias,
                            const D1DataType* p_acc1_bias,
+                           D0DataType* p_d0grad_grid,
+                           D1DataType* p_d1grad_grid,
                            const std::vector<index_t>& a_gs_ms_ks_lengths,
                            const std::vector<index_t>& a_gs_ms_ks_strides,
                            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1412,6 +1407,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
...
@@ -1454,6 +1451,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1493,6 +1492,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @ c0c52268
...
@@ -65,6 +65,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
...
@@ -120,11 +121,19 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-       tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       if(p_d0_grid != nullptr)
+       {
+           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       }
+       if(p_d0grad_grid != nullptr)
+       {
+           tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+       }
    }
    if constexpr(Deterministic)
    {
...
@@ -141,6 +150,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -178,6 +188,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -212,6 +223,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
...
@@ -514,32 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-   }
-   //
-   // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-   //
-   //
-   // dQ = alpha * dS * K
-   //
-   // QGrad in Gemm C position
-   static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
-                                           const std::vector<index_t>& q_gs_ms_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   }
-   //
-   // dK = alpha * dS^T * Q
-   //
-   // KGrad in Gemm C position
-   static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
-                                           const std::vector<index_t>& k_gs_ns_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+       return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }
    static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -570,12 +557,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
                                         const std::vector<index_t>& d_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+       return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    }
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
-   using D0GridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using D0GridDesc_G_M_N     = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using YGridDesc_M_O        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
...
@@ -583,7 +570,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    using BGridDesc_G_N_K  = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-   using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using D0GridDesc_M_N   = decltype(MakeD0GridDescriptor_M_N({}, {}));
    using KGridDesc_N_K    = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
...
@@ -755,6 +742,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            OutputDataType* p_vgrad_grid,
            const D0DataType* p_acc0_bias,
            const D1DataType* p_acc1_bias,
+           D0DataType* p_d0grad_grid,
+           D1DataType* p_d1grad_grid,
            const std::vector<index_t>& a_gs_ms_ks_lengths,
            const std::vector<index_t>& a_gs_ms_ks_strides,
            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -790,6 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
...
@@ -814,7 +804,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              z_grid_desc_g_m_n_{
-                 Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
+                 Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              y_grid_desc_mblock_mperblock_oblock_operblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              a_element_op_{a_element_op},
...
@@ -839,10 +829,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
              p_drop_{p_drop}
        {
            // TODO: implement bias addition
-           ignore = p_acc0_bias;
+           ignore = p_d1grad_grid;
            ignore = p_acc1_bias;
-           ignore = acc0_bias_gs_ms_ns_lengths;
-           ignore = acc0_bias_gs_ms_ns_strides;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
...
@@ -862,7 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                    GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
-               d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
+               d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...
@@ -926,6 +914,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -1049,6 +1038,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
+                   arg.p_d0grad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
...
@@ -1200,6 +1190,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                            OutputDataType* p_vgrad_grid,
                            const D0DataType* p_acc0_bias,
                            const D1DataType* p_acc1_bias,
+                           D0DataType* p_d0grad_grid,
+                           D1DataType* p_d1grad_grid,
                            const std::vector<index_t>& a_gs_ms_ks_lengths,
                            const std::vector<index_t>& a_gs_ms_ks_strides,
                            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1237,6 +1229,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
...
@@ -1278,6 +1272,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                        void* p_vgrad_grid,
                        const D0DataType* p_acc0_bias,
                        const D1DataType* p_acc1_bias,
+                       D0DataType* p_d0grad_grid,
+                       D1DataType* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1316,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<const D0DataType*>(p_d0grad_grid),
+           static_cast<const D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @ c0c52268
...
@@ -65,6 +65,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
...
@@ -120,13 +121,21 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-       tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       if(p_d0_grid != nullptr)
+       {
+           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       }
+       if(p_d0grad_grid != nullptr)
+       {
+           tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+       }
    }
    if constexpr(Deterministic)
    {
        for(index_t i = 0; i < nblock; i++)
...
@@ -142,6 +151,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -179,6 +189,7 @@ __global__ void
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
+           tmp_p_d0grad_grid,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -213,6 +224,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
...
@@ -522,39 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
                                         const std::vector<index_t>& d_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+       return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    }
    // Z in Gemm0 C position
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-   }
-   //
-   // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-   //
-   //
-   // dQ = alpha * dS * K
-   //
-   // QGrad in Gemm C position
-   static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
-                                           const std::vector<index_t>& q_gs_ms_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   }
-   //
-   // dK = alpha * dS^T * Q
-   //
-   // KGrad in Gemm C position
-   static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
-                                           const std::vector<index_t>& k_gs_ns_ks_strides)
-   {
-       return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+       return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }
    static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -584,7 +571,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
-   using D0GridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using D0GridDesc_G_M_N     = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using YGridDesc_M_O        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
...
@@ -592,7 +579,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    using BGridDesc_G_N_K  = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-   using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+   using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
    using D0GridDesc_M_N   = decltype(MakeD0GridDescriptor_M_N({}, {}));
    using KGridDesc_N_K    = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
...
@@ -771,6 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            OutputDataType* p_vgrad_grid,
            const D0DataType* p_acc0_bias,
            const D1DataType* p_acc1_bias,
+           D0DataType* p_d0grad_grid,
+           D1DataType* p_d1grad_grid,
            const std::vector<index_t>& a_gs_ms_ks_lengths,
            const std::vector<index_t>& a_gs_ms_ks_strides,
            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -806,6 +795,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
...
@@ -829,7 +819,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              z_grid_desc_g_m_n_{
-                 Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
+                 Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              y_grid_desc_mblock_mperblock_oblock_operblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              a_element_op_{a_element_op},
...
@@ -855,6 +845,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
+           ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
...
@@ -875,7 +866,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                    GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
-               d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
+               d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                    acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
                d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...
@@ -925,6 +916,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
                          << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
                          << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
+               std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0)
+                         << ", " << d0_grid_desc_g_m_n_.GetLength(I1) << ", "
+                         << d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
            }
            // pointers
...
@@ -939,6 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -1066,6 +1062,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                    arg.p_ygrad_grid_,
                    arg.p_qgrad_grid_,
                    arg.p_kgrad_grid_,
+                   arg.p_d0grad_grid_,
                    arg.p_vgrad_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
...
@@ -1233,6 +1230,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                            OutputDataType* p_vgrad_grid,
                            const D0DataType* p_acc0_bias,
                            const D1DataType* p_acc1_bias,
+                           D0DataType* p_d0grad_grid,
+                           D1DataType* p_d1grad_grid,
                            const std::vector<index_t>& a_gs_ms_ks_lengths,
                            const std::vector<index_t>& a_gs_ms_ks_strides,
                            const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1270,6 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
...
@@ -1311,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1349,6 +1352,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @ c0c52268
...
@@ -162,13 +162,16 @@ __global__ void
                                  : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset =
            __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
+       if(arg_ptr[group_id].p_d0_grid_ != nullptr)
            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+       if(arg_ptr[group_id].p_d0grad_grid_)
+           tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
    }
    if constexpr(Deterministic)
    {
...
@@ -185,6 +188,7 @@ __global__ void
            arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
            arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
            arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+           tmp_p_d0grad_grid,
            arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -222,6 +226,7 @@ __global__ void
            arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
            arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
            arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+           tmp_p_d0grad_grid,
            arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
            p_shared,
            a_element_op,
...
@@ -540,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+       return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }
    static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -572,8 +577,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                          const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
+       return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
                                                  acc0_bias_gs_ms_ns_strides);
    }
    static auto
...
@@ -581,8 +586,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                          const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
    {
-       return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
+       return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
                                                    acc0_bias_gs_ms_ns_strides);
    }
    static auto MakeDGridDescriptor_M(index_t MRaw)
...
@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        const InputDataType* p_ygrad_grid_;
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        OutputDataType* p_vgrad_grid_;
        // tensor descriptors for block/thread-wise copy
...
@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 std::vector<void*>& p_Vgrads,
                 const std::vector<const void*>& p_acc0_bias_vec,
                 const std::vector<const void*>& p_acc1_bias_vec,
+                const std::vector<void*>& p_d0grads,
+                const std::vector<void*>& p_d1grads,
                 const std::vector<ProblemDesc>& problem_desc_vec,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
...
@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                 (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                  ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                0 == p_acc1_bias_vec.size()))
+                0 == p_acc1_bias_vec.size() &&
+                (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                 ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                0 == p_d1grads.size()))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
            }
...
@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
p_ygrad_grid
=
static_cast
<
const
InputDataType
*>
(
p_Ygrads
[
i
]);
const
auto
p_ygrad_grid
=
static_cast
<
const
InputDataType
*>
(
p_Ygrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Qgrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Qgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Kgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Kgrads
[
i
]);
auto
p_vgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Vgrads
[
i
]);
auto
p_d0grad_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_d0grads
.
size
())
==
group_count_
)
?
static_cast
<
D0DataType
*>
(
p_d0grads
[
i
])
:
nullptr
;
auto
p_vgrad_grid
=
static_cast
<
OutputDataType
*>
(
p_Vgrads
[
i
]);
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
...
@@ -983,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -983,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeC
0
GridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
...
@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_ygrad_grid
,
p_ygrad_grid
,
p_qgrad_grid
,
p_qgrad_grid
,
p_kgrad_grid
,
p_kgrad_grid
,
p_d0grad_grid
,
p_vgrad_grid
,
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
...
@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias_vec
,
const
std
::
vector
<
void
*>&
p_d0grads
,
const
std
::
vector
<
void
*>&
p_d1grads
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads
,
p_Vgrads
,
p_acc0_bias_vec
,
p_acc0_bias_vec
,
p_acc1_bias_vec
,
p_acc1_bias_vec
,
p_d0grads
,
p_d1grads
,
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias_vec
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias_vec
,
const
std
::
vector
<
void
*>&
p_d0grads
,
const
std
::
vector
<
void
*>&
p_d1grads
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads
,
p_Vgrads
,
p_acc0_bias_vec
,
// cast in struct Argument
p_acc0_bias_vec
,
// cast in struct Argument
p_acc1_bias_vec
,
// cast in struct Argument
p_acc1_bias_vec
,
// cast in struct Argument
p_d0grads
,
p_d1grads
,
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
...
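The new p_d0grads/p_d1grads arguments follow the existing optional-bias convention: each group either supplies its own dBias-gradient buffer, or the feature is disabled by passing an empty vector. A minimal stand-alone sketch of that per-group pointer selection; the type alias and helper name below are illustrative only, not part of the CK API:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the D0DataType used in the diff (e.g. fp16 storage).
using D0DataType = unsigned short;

// Resolve the i-th group's dBias-gradient pointer: a typed pointer when the
// caller supplied one buffer per group, nullptr when the vector is empty.
inline D0DataType* select_d0grad_ptr(const std::vector<void*>& p_d0grads,
                                     std::size_t group_count,
                                     std::size_t i)
{
    return (p_d0grads.size() == group_count) ? static_cast<D0DataType*>(p_d0grads[i])
                                             : nullptr;
}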
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @ c0c52268
...
@@ -160,13 +160,17 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset =
             __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0grad_grid_)
+            tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
     }
     if constexpr(Deterministic)
...
@@ -184,6 +188,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -221,6 +226,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -602,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                         const std::vector<index_t>& z_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     }

     static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -634,8 +640,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                  acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                   acc0_bias_gs_ms_ns_strides);
     }

     static auto
...
@@ -643,8 +649,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                    acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                     acc0_bias_gs_ms_ns_strides);
     }

     static auto MakeDGridDescriptor_M(index_t MRaw)
...
@@ -682,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
     using KGridDesc_N_K    = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
     using D0GridDesc_M_N   = decltype(MakeD0GridDescriptor_M_N({}, {}));
...
@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
         const InputDataType* p_ygrad_grid_;
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
+        D0DataType* p_d0grad_grid_;
         OutputDataType* p_vgrad_grid_;

         // tensor descriptors for block/thread-wise copy
...
@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                  std::vector<void*>& p_Vgrads,
                  const std::vector<const void*>& p_acc0_bias_vec,
                  const std::vector<const void*>& p_acc1_bias_vec,
+                 const std::vector<void*>& p_d0grads,
+                 const std::vector<void*>& p_d1grads,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
...
@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                  group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                  (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                   ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                 0 == p_acc1_bias_vec.size()))
+                 0 == p_acc1_bias_vec.size() &&
+                 (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                  ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                 0 == p_d1grads.size()))
             {
                 throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
             }
...
@@ -1007,7 +1019,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                 auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
                 auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
-                auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);
+                auto p_d0grad_grid =
+                    (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                        ? static_cast<D0DataType*>(p_d0grads[i])
+                        : nullptr;
+                auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);

                 const auto& problem_desc = problem_desc_vec[i];
...
@@ -1053,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                     problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
                 const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
                     tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
-                const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
                     problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
                 const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
                     problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...
@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                     p_ygrad_grid,
                     p_qgrad_grid,
                     p_kgrad_grid,
+                    p_d0grad_grid,
                     p_vgrad_grid,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
...
@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                              std::vector<void*>& p_Vgrads,
                              const std::vector<const void*>& p_acc0_bias_vec,
                              const std::vector<const void*>& p_acc1_bias_vec,
+                             const std::vector<void*>& p_d0grads,
+                             const std::vector<void*>& p_d1grads,
                              const std::vector<ProblemDesc>& problem_desc_vec,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
...
@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                         p_Vgrads,
                         p_acc0_bias_vec,
                         p_acc1_bias_vec,
+                        p_d0grads,
+                        p_d1grads,
                         problem_desc_vec,
                         a_element_op,
                         b_element_op,
...
@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                         std::vector<void*>& p_Vgrads,
                         const std::vector<const void*>& p_acc0_bias_vec,
                         const std::vector<const void*>& p_acc1_bias_vec,
+                        const std::vector<void*>& p_d0grads,
+                        const std::vector<void*>& p_d1grads,
                         const std::vector<ProblemDesc>& problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
...
@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
             p_Vgrads,
             p_acc0_bias_vec, // cast in struct Argument
             p_acc1_bias_vec, // cast in struct Argument
+            p_d0grads,
+            p_d1grads,
             problem_desc_vec,
             a_element_op,
             b_element_op,
...
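The tightened argument check repeated in each of these device ops accepts per-group vectors whose size is either the group count (the feature is enabled for every group) or zero (the feature is disabled), and it still rejects any acc1/d1 gradient buffers. A small sketch of the same rule in isolation, assuming plain std::vector arguments; the helper name is illustrative:

#include <cstddef>
#include <stdexcept>
#include <vector>

// Per-group argument vectors must either match the group count or be empty;
// gradient buffers for the second bias (d1) are not supported, so that
// vector must always be empty.
inline void check_group_sizes(std::size_t group_count,
                              const std::vector<void*>& p_d0grads,
                              const std::vector<void*>& p_d1grads)
{
    const bool d0_ok = p_d0grads.size() == group_count || p_d0grads.empty();
    if(!(d0_ok && p_d1grads.empty()))
    {
        throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
    }
}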
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @ c0c52268
...
@@ -103,13 +103,17 @@ __global__ void
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset =
             __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0grad_grid_)
+            tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
     }
     if constexpr(Deterministic)
     {
...
@@ -126,6 +130,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -164,6 +169,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -471,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                         const std::vector<index_t>& z_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     }

     static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -503,8 +509,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                  acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                   acc0_bias_gs_ms_ns_strides);
     }

     static auto
...
@@ -512,8 +518,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                    acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                     acc0_bias_gs_ms_ns_strides);
     }

     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...
@@ -526,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
     using KGridDesc_N_K    = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
     using D0GridDesc_M_N   = decltype(MakeD0GridDescriptor_M_N({}, {}));
...
@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         const InputDataType* p_ygrad_grid_;
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
+        D0DataType* p_d0grad_grid_;
         OutputDataType* p_vgrad_grid_;

         // tensor descriptors for block/thread-wise copy
...
@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                  std::vector<void*>& p_Vgrads,
                  const std::vector<const void*>& p_acc0_bias_vec,
                  const std::vector<const void*>& p_acc1_bias_vec,
+                 const std::vector<void*>& p_d0grads,
+                 const std::vector<void*>& p_d1grads,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
...
@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                  group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
                  (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                   ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                 0 == p_acc1_bias_vec.size()))
+                 0 == p_acc1_bias_vec.size() &&
+                 (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                  ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                 0 == p_d1grads.size()))
             {
                 throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
             }
...
@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                 auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
                 auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
-                auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);
+                auto p_d0grad_grid =
+                    (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                        ? static_cast<D0DataType*>(p_d0grads[i])
+                        : nullptr;
+                auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);

                 const auto& problem_desc = problem_desc_vec[i];
...
@@ -862,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                     problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
                 const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
                     tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
-                const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
                     problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
                 const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
                     problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...
@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                     p_ygrad_grid,
                     p_qgrad_grid,
                     p_kgrad_grid,
+                    p_d0grad_grid,
                     p_vgrad_grid,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
...
@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                              std::vector<void*>& p_Vgrads,
                              const std::vector<const void*>& p_acc0_bias_vec,
                              const std::vector<const void*>& p_acc1_bias_vec,
+                             const std::vector<void*>& p_d0grads,
+                             const std::vector<void*>& p_d1grads,
                              const std::vector<ProblemDesc>& problem_desc_vec,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
...
@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                         p_Vgrads,
                         p_acc0_bias_vec,
                         p_acc1_bias_vec,
+                        p_d0grads,
+                        p_d1grads,
                         problem_desc_vec,
                         a_element_op,
                         b_element_op,
...
@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                         std::vector<void*>& p_Vgrads,
                         const std::vector<const void*>& p_acc0_bias_vec,
                         const std::vector<const void*>& p_acc1_bias_vec,
+                        const std::vector<void*>& p_d0grads,
+                        const std::vector<void*>& p_d1grads,
                         const std::vector<ProblemDesc>& problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
...
@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             p_Vgrads,
             p_acc0_bias_vec, // cast in struct Argument
             p_acc1_bias_vec, // cast in struct Argument
+            p_d0grads,
+            p_d1grads,
             problem_desc_vec,
             a_element_op,
             b_element_op,
...
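In the __global__ kernel hunks of these grouped backward ops, the forward bias and its gradient share a single per-batch offset obtained from GetD0BasePtr, and each pointer is advanced only if the corresponding buffer was actually supplied. A host-side sketch of that guarded offsetting, assuming illustrative names and a plain template type:

// Sketch only: advance the bias and bias-gradient pointers by the shared
// per-batch offset, leaving either pointer null when its buffer is absent.
template <typename T>
void compute_batch_ptrs(const T* p_d0_grid,
                        T* p_d0grad_grid,
                        long d0_batch_offset,
                        const T*& p_d0_batch,
                        T*& p_d0grad_batch)
{
    p_d0_batch     = (p_d0_grid != nullptr) ? p_d0_grid + d0_batch_offset : nullptr;
    p_d0grad_batch = (p_d0grad_grid != nullptr) ? p_d0grad_grid + d0_batch_offset : nullptr;
}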
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @ c0c52268
...
@@ -102,13 +102,16 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset =
             __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0grad_grid_)
+            tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
     }
     if constexpr(Deterministic)
...
@@ -126,6 +129,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -164,6 +168,7 @@ __global__ void
             arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
             arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
             arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            tmp_p_d0grad_grid,
             arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
             p_shared,
             a_element_op,
...
@@ -534,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                         const std::vector<index_t>& z_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     }

     static auto MakeLSEGridDescriptor_M(index_t MRaw)
...
@@ -566,8 +571,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                  acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                   acc0_bias_gs_ms_ns_strides);
     }

     static auto
...
@@ -575,8 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                             const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
-                                                    acc0_bias_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
+                                                     acc0_bias_gs_ms_ns_strides);
     }

     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...
@@ -589,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    using ZGridDesc_G_M_N  = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N  = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
     using KGridDesc_N_K    = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
     using D0GridDesc_M_N   = decltype(MakeD0GridDescriptor_M_N({}, {}));
...
@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const InputDataType* p_ygrad_grid_;
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
+        D0DataType* p_d0grad_grid_;
         OutputDataType* p_vgrad_grid_;

         // tensor descriptors for block/thread-wise copy
...
@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                  std::vector<void*>& p_Vgrads,
                  const std::vector<const void*>& p_acc0_bias_vec,
                  const std::vector<const void*>& p_acc1_bias_vec,
+                 const std::vector<void*>& p_d0grads,
+                 const std::vector<void*>& p_d1grads,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
...
@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                  group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
                  (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                   ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                 0 == p_acc1_bias_vec.size()))
+                 0 == p_acc1_bias_vec.size() &&
+                 (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                  ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                 0 == p_d1grads.size()))
             {
                 throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
             }
...
@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                 auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
                 auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
-                auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);
+                auto p_d0grad_grid =
+                    (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                        ? static_cast<D0DataType*>(p_d0grads[i])
+                        : nullptr;
+                auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);

                 const auto& problem_desc = problem_desc_vec[i];
...
@@ -933,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
                 const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
                     tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
-                const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
                     problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
                 const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
                     problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...
@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     p_ygrad_grid,
                     p_qgrad_grid,
                     p_kgrad_grid,
+                    p_d0grad_grid,
                     p_vgrad_grid,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
...
@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                              std::vector<void*>& p_Vgrads,
                              const std::vector<const void*>& p_acc0_bias_vec,
                              const std::vector<const void*>& p_acc1_bias_vec,
+                             const std::vector<void*>& p_d0grads,
+                             const std::vector<void*>& p_d1grads,
                              const std::vector<ProblemDesc>& problem_desc_vec,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
...
@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         p_Vgrads,
                         p_acc0_bias_vec,
                         p_acc1_bias_vec,
+                        p_d0grads,
+                        p_d1grads,
                         problem_desc_vec,
                         a_element_op,
                         b_element_op,
...
@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         std::vector<void*>& p_Vgrads,
                         const std::vector<const void*>& p_acc0_bias_vec,
                         const std::vector<const void*>& p_acc1_bias_vec,
+                        const std::vector<void*>& p_d0grads,
+                        const std::vector<void*>& p_d1grads,
                         const std::vector<ProblemDesc>& problem_desc_vec,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
...
@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             p_Vgrads,
             p_acc0_bias_vec, // cast in struct Argument
             p_acc1_bias_vec, // cast in struct Argument
+            p_d0grads,
+            p_d1grads,
             problem_desc_vec,
             a_element_op,
             b_element_op,
...
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
View file @ c0c52268
...
@@ -119,6 +119,15 @@ struct GemmGemmPadder
             c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
     }

+    // C[M, Gemm1N] = C[M, N]
+    template <typename C0Desc_MRaw_NRaw>
+    __host__ __device__ constexpr auto
+    PadC0Descriptor_M_N(const C0Desc_MRaw_NRaw& c_desc_mraw_nraw) const
+    {
+        return PadTensorDescriptor(
+            c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
+    }
+
     MPerTileType MPerTile_;
     NPerTileType NPerTile_;
     KPerTileType KPerTile_;
...
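PadC0Descriptor_M_N gives the bias (C0) matrix the same round-up-to-tile padding the other GEMM operands already receive, so the bias-gradient path can work on the padded M/N extents. A scalar sketch of the rounding rule it relies on (standalone, not CK code):

#include <cassert>

// Round a raw length up to the next multiple of the block tile size.
constexpr int pad_to_multiple(int raw_len, int per_tile)
{
    return ((raw_len + per_tile - 1) / per_tile) * per_tile;
}

static_assert(pad_to_multiple(70, 32) == 96, "70 rounds up to three 32-wide tiles");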
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
c0c52268
...
@@ -88,6 +88,10 @@ template <typename InputDataType,
...
@@ -88,6 +88,10 @@ template <typename InputDataType,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
struct
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
{
static_assert
(
KPerBlock
==
Gemm1NPerBlock
);
static_assert
(
MPerBlock
%
Gemm1KPerBlock
==
0
);
static_assert
(
NPerBlock
%
Gemm2KPerBlock
==
0
);
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -1213,7 +1217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1213,7 +1217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
struct
D0
Loade
r
struct
D0
Operato
r
{
{
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
TypeTransform
struct
TypeTransform
...
@@ -1229,17 +1233,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1229,17 +1233,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
index_t
Size0
=
0
;
static
constexpr
index_t
Size0
=
0
;
static
constexpr
index_t
Size
=
sizeof
(
ck
::
half_t
);
static
constexpr
index_t
Size
=
sizeof
(
ck
::
half_t
);
};
};
static
constexpr
index_t
NThreadClusterLengths
=
MPerXdl
;
static
constexpr
index_t
NThreadClusterLengths
=
32
;
static_assert
(
M
PerXdl
<
=
KPerBlock
);
static_assert
(
N
PerXdl
=
=
32
);
static_assert
(
D0BlockTransferSrcScalarPerVector
*
NThreadClusterLengths
<=
NPerBlock
,
static_assert
(
D0BlockTransferSrcScalarPerVector
*
NThreadClusterLengths
<=
NPerBlock
,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"
);
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"
);
__host__
__device__
static
constexpr
auto
GetD0Block
Write
Descriptor_M0_N0_M1_M2_N1_M3
()
__host__
__device__
static
constexpr
auto
GetD0Block
Global
Descriptor_M0_N0_M1_M2_N1_M3
()
{
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor_packed
(
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
make_tuple
(
I1
,
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
}
}
__host__
__device__
static
constexpr
auto
GetD0Block
Read
Descriptor_N0_N1_M0_M1_M2
()
__host__
__device__
static
constexpr
auto
GetD0Block
Vgpr
Descriptor_N0_N1_M0_M1_M2
()
{
{
constexpr
auto
d0_raw_m0_n_m1
=
constexpr
auto
d0_raw_m0_n_m1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
...
@@ -1254,15 +1257,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1254,15 +1257,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
return
d0_n0_n1_m0_m1_m2
;
}
}
static
constexpr
auto
d0_block_
write
_desc_m0_n0_m1_m2_n1_m3
=
static
constexpr
auto
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
=
GetD0Block
Write
Descriptor_M0_N0_M1_M2_N1_M3
();
GetD0Block
Global
Descriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_
read
_desc_n0_n1_m0_m1_m2
=
static
constexpr
auto
d0_block_
src
_desc_n0_n1_m0_m1_m2
=
GetD0Block
Read
Descriptor_N0_N1_M0_M1_M2
();
GetD0Block
Vgpr
Descriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
static
constexpr
auto
&
d0grad_block_dst_desc_n0_n1_m0_m1_m2
=
d0_block_src_desc_n0_n1_m0_m1_m2
;
static
constexpr
auto
&
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
=
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
;
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1273,34 +1281,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1273,34 +1281,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_
write
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
decltype
(
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
5
,
// DstVectorDim
5
,
// DstVectorDim
D0BlockTransferSrcScalarPerVector
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
true
,
// DstResetCoord
true
,
// DstResetCoord
1
>
;
1
>
;
using
D0Thread
W
iseCopy
=
using
D0Thread
w
iseCopy
LdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
read
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_
src
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
2
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
using
D0GradThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
typename
TypeTransform
<
D0DataType
>::
Type
,
decltype
(
d0_thread_desc_
),
decltype
(
d0grad_block_dst_desc_n0_n1_m0_m1_m2
),
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
4
,
// VectorDim
4
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
// GlobalMemoryDataOperation
1
,
// DstScalarStrideInVector
true
>
;
using
D0GradBlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
I1
,
D0M1
,
NPerBlock
,
D0M2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
5
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
1
,
1
,
true
,
true
,
// DstResetCoord
1
>
;
};
};
struct
SharedMemTrait
struct
SharedMemTrait
...
@@ -1335,11 +1386,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            q_block_space_size_aligned.value;

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);

        static constexpr auto d0_block_space_offset =
            (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
             q_block_space_size_aligned.value) *
-           sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+           sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...
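The offset arithmetic above places the D0 bias tile directly behind the K, dY and Q tiles in LDS, converting the byte size of those tiles into D0DataType elements. A small constexpr sketch of the same arithmetic, with made-up tile sizes and assuming 2-byte element types:

#include <cstddef>

// Illustrative check of the LDS layout arithmetic: the D0 tile starts right after the
// K, dY and Q tiles, and its offset is counted in D0DataType elements. All sizes below
// are example values, not the kernel's defaults.
constexpr std::size_t k_elems     = 4096; // k_block_space_size_aligned (GemmDataType elements)
constexpr std::size_t ygrad_elems = 4096; // ygrad_block_space_size_aligned
constexpr std::size_t q_elems     = 2048; // q_block_space_size_aligned
constexpr std::size_t gemm_bytes  = 2;    // sizeof(GemmDataType), e.g. fp16
constexpr std::size_t d0_bytes    = 2;    // TypeTransform<D0DataType>::Size, e.g. fp16

constexpr std::size_t d0_block_space_offset_elems =
    (k_elems + ygrad_elems + q_elems) * gemm_bytes / d0_bytes;

static_assert(d0_block_space_offset_elems == 10240, "offset is expressed in D0DataType elements");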
@@ -1356,7 +1407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
...
@@ -1379,6 +1430,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                               const InputDataType* __restrict__ p_ygrad_grid,
                               OutputDataType* __restrict__ p_qgrad_grid,
                               OutputDataType* __restrict__ p_kgrad_grid,
+                              D0DataType* __restrict__ p_d0grad_grid,
                               OutputDataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
...
@@ -1846,17 +1898,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        // gemm0 M loop
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+       auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
+           D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+       auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
+           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});

        if constexpr(Deterministic)
        {
            block_sync_lds();
...
@@ -1992,49 +2058,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                // add bias
                if constexpr(!is_same<D0DataType, void>::value)
                {
+                   if(p_d0_grid != nullptr)
+                   {
                        static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
                        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                            p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                            static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                           D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                           D0Operator::d0_thread_desc_.GetElementSpaceSize());
                        static_for<0, D0M0, 1>{}([&](auto mr) {
                            // load data to lds
                            d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                                d0_grid_buf);
                            d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                                d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                            d0_block_copy_global_to_lds.RunWrite(
-                               D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                               D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                            block_sync_lds();
                            // read data form lds
-                           d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                           d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_block_buf,
-                                                          D0Loader::d0_thread_desc_,
+                                                          D0Operator::d0_thread_desc_,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_thread_buf);
                            // bias add
                            static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                                constexpr index_t c_offset =
                                    c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                                s_slash_p_thread_buf(Number<c_offset>{}) +=
                                    ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                            });
                        });
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3,
                            make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   }
                }
                // P_i: = softmax(scalar * S_i:)
...
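The inner static_for above is plain element-wise accumulation: each thread converts its slice of the bias tile to the accumulation type and adds it onto the S = Q*K^T partial results it already holds in registers. A scalar sketch of that step, with illustrative buffer names:

#include <cstddef>
#include <vector>

using FloatAcc = float;

// Hypothetical per-thread view: s_acc holds this thread's S = Q*K^T partial results and
// d0_tile the matching slice of the attention bias in a narrower type (modelled as float
// here). Mirrors the s_slash_p_thread_buf(...) += type_convert<FloatGemmAcc>(...) line.
void add_bias_to_s(std::vector<FloatAcc>& s_acc, const std::vector<float>& d0_tile, std::size_t offset)
{
    for(std::size_t i = 0; i < d0_tile.size(); ++i)
        s_acc[offset + i] += static_cast<FloatAcc>(d0_tile[i]);
}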
@@ -2125,6 +2195,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                             : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(D0Operator::d0_thread_desc_,
+                                                          make_tuple(mr, I0, I0, I0, I0),
+                                                          sgrad_thread_buf,
+                                                          D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+                                                          d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0grad_block_copy_lds_to_global.Run(
+                           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            // gemm dV
            // dV = P_drop^T * dY
            {
...
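For an additive bias the bias gradient is just the gradient of the softmax input, so the new path above takes the dS values each thread already holds, rescales them through the Scale{rp_dropout} element-wise op, and stages them out to global memory through LDS; when the caller passes a null p_d0grad_grid the whole block is skipped. A hedged scalar sketch of that write-back, with illustrative names and shapes:

#include <cstddef>
#include <vector>

// Hypothetical scalar model of the bias-gradient write-back: each dS element is rescaled
// by the dropout factor (as the Scale{rp_dropout} element-wise op does above) and stored.
// A null destination mirrors the p_d0grad_grid != nullptr guard.
void store_bias_grad(const std::vector<float>& dS, float rp_dropout, std::vector<float>* dBias)
{
    if(dBias == nullptr)
        return; // caller did not request the bias gradient
    dBias->resize(dS.size());
    for(std::size_t i = 0; i < dS.size(); ++i)
        (*dBias)[i] = rp_dropout * dS[i];
}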
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @ c0c52268
...
@@ -96,6 +96,10 @@ template <typename InputDataType,
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
+    static_assert(Gemm1NPerBlock % KPerBlock == 0);
+    static_assert(MPerBlock % Gemm1KPerBlock == 0);
+    static_assert(NPerBlock % Gemm2KPerBlock == 0);
+
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");
...
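The added static_asserts pin down the tile-size relationships the Q-loop kernel assumes; a violating instantiation would otherwise fail much later with partial tiles that the block-level copies cannot handle. A self-contained sketch of the same compile-time checks, using example (not default) tile sizes:

// Illustrative compile-time guard with made-up tile sizes.
constexpr int MPerBlock      = 128;
constexpr int NPerBlock      = 128;
constexpr int KPerBlock      = 32;
constexpr int Gemm1NPerBlock = 64;
constexpr int Gemm1KPerBlock = 32;
constexpr int Gemm2KPerBlock = 64;

static_assert(Gemm1NPerBlock % KPerBlock == 0, "Gemm1 N tile must be a multiple of the K tile");
static_assert(MPerBlock % Gemm1KPerBlock == 0, "M tile must be divisible by the Gemm1 K tile");
static_assert(NPerBlock % Gemm2KPerBlock == 0, "N tile must be divisible by the Gemm2 K tile");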
@@ -1292,7 +1296,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;

-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform
...
@@ -1312,13 +1316,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        static_assert(NPerXdl == 32);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
-           // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }

-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
...
@@ -1333,15 +1336,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }

-       static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-           GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-       static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-           GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+       static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
+           GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+       static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
+           GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

-       using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+       static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 = d0_block_src_desc_n0_n1_m0_m1_m2;
+       static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 = d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
+
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
...
@@ -1352,34 +1360,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>,               // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type, // SrcData
            typename TypeTransform<D0DataType>::Type, // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,       // SrcDesc
-           decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
+           decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3),   // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,               // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,               // DstDimAccessOrder
            4,                                        // SrcVectorDim
            5,                                        // DstVectorDim
-           D0BlockTransferSrcScalarPerVector,        // SrcScalarPerVector
+           4,                                        // SrcScalarPerVector
            4,                                        // DstScalarPerVector
            1,
            1,
            true,
            true,                                     // DstResetCoord
            1>;

-       using D0ThreadWiseCopy =
+       using D0ThreadwiseCopyLdsToVgpr =
            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
                                             typename TypeTransform<D0DataType>::Type, // DstData
-                                            decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+                                            decltype(d0_block_src_desc_n0_n1_m0_m1_m2),  // SrcDesc
                                             decltype(d0_thread_desc_), // DstDesc
                                             Sequence<1, 1, 4, 1, 4>,   // SliceLengths
                                             Sequence<0, 1, 2, 3, 4>,   // DimAccessOrder
                                             4,                         // SrcVectorDim
-                                            2,                         // SrcScalarPerVector
+                                            4,                         // SrcScalarPerVector
                                             2>;

+       using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+           FloatGemmAcc,
+           typename TypeTransform<D0DataType>::Type,
+           decltype(d0_thread_desc_),
+           decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
+           tensor_operation::element_wise::Scale,    // CElementwiseOperation
+           Sequence<1, 1, 4, 1, 4>,                  // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,                  // AccessOrder
+           4,                                        // VectorDim
+           4,                                        // ScalarPerVector
+           InMemoryDataOperationEnum::Set,           // GlobalMemoryDataOperation
+           1,                                        // DstScalarStrideInVector
+           true>;
+
+       using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+           Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
+           Sequence<0, 1, 2, 3, 5, 4>,               // ThreadClusterArrangeOrder
+           typename TypeTransform<D0DataType>::Type, // SrcData
+           typename TypeTransform<D0DataType>::Type, // DstData
+           decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+           D0GridDescriptor_M0_N0_M1_M2_N1_M3,       // DstDesc
+           Sequence<0, 1, 2, 4, 3, 5>,               // SrcDimAccessOrder
+           Sequence<0, 1, 2, 3, 5, 4>,               // DstDimAccessOrder
+           5,                                        // SrcVectorDim
+           4,                                        // DstVectorDim
+           4,                                        // SrcScalarPerVector
+           D0BlockTransferSrcScalarPerVector,        // DstScalarPerVector
+           1,
+           1,
+           true,
+           true,                                     // DstResetCoord
+           1>;
    };

    struct SharedMemTrait
...
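The d0grad_block_dst/src aliases above are what keep the LDS budget unchanged: the outgoing bias-gradient tile reuses exactly the layouts of the incoming bias tile, only with the source and destination roles swapped, so both directions share one LDS window. A small illustrative sketch of that aliasing idea (the TileLayout type and sizes are made up):

// Illustrative only: the bias-gradient path reuses the incoming bias tile's layouts
// with the roles swapped, so one LDS window serves both directions.
struct TileLayout { int rows; int cols; };

constexpr TileLayout d0_block_dst{8, 128};                    // global -> LDS layout
constexpr TileLayout d0_block_src{8, 128};                    // LDS -> VGPR layout
constexpr const TileLayout& d0grad_block_dst = d0_block_src;  // VGPR -> LDS reuses it
constexpr const TileLayout& d0grad_block_src = d0_block_dst;  // LDS -> global reuses it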
@@ -1414,10 +1465,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
        static constexpr auto d0_block_space_offset =
            k_block_space_size_aligned.value * sizeof(GemmDataType) /
-           D0Loader::template TypeTransform<D0DataType>::Size;
+           D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...
@@ -1442,7 +1493,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
...
@@ -1470,6 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                               const InputDataType* __restrict__ p_ygrad_grid,
                               OutputDataType* __restrict__ p_qgrad_grid,
                               OutputDataType* __restrict__ p_kgrad_grid,
+                              D0DataType* __restrict__ p_d0grad_grid,
                               OutputDataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
...
@@ -1967,17 +2019,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+       auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
+           D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+       auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
+           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});

        if constexpr(Deterministic)
        {
            block_sync_lds();
...
@@ -2143,50 +2209,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                // add bias
                if constexpr(!is_same<D0DataType, void>::value)
                {
+                   if(p_d0_grid != nullptr)
+                   {
                        static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
                        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                            p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                            static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                           D0Loader::d0_thread_desc_.GetElementSpaceSize());
-                       ignore = d0_thread_buf;
+                           D0Operator::d0_thread_desc_.GetElementSpaceSize());
                        static_for<0, D0M0, 1>{}([&](auto mr) {
                            // load data to lds
                            d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                                d0_grid_buf);
                            d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                                d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                            d0_block_copy_global_to_lds.RunWrite(
-                               D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                               D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                            block_sync_lds();
                            // read data form lds
-                           d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                           d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_block_buf,
-                                                          D0Loader::d0_thread_desc_,
+                                                          D0Operator::d0_thread_desc_,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_thread_buf);
                            // bias add
                            static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                                constexpr index_t c_offset =
                                    c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                                s_slash_p_thread_buf(Number<c_offset>{}) +=
                                    ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                            });
                        });
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3,
                            make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   }
                }
                // P_i: = softmax(scalar * S_i:)
...
@@ -2393,6 +2462,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                             : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(D0Operator::d0_thread_desc_,
+                                                          make_tuple(mr, I0, I0, I0, I0),
+                                                          sgrad_thread_buf,
+                                                          D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
+                                                          d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0grad_block_copy_lds_to_global.Run(
+                           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @ c0c52268
...
@@ -87,6 +87,10 @@ template <typename InputDataType,
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
+    static_assert(KPerBlock == Gemm1NPerBlock);
+    static_assert(MPerBlock % Gemm1KPerBlock == 0);
+    static_assert(NPerBlock % Gemm2KPerBlock == 0);
+
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");
...
@@ -1281,7 +1285,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;

-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform
...
@@ -1297,17 +1301,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            static constexpr index_t Size0 = 0;
            static constexpr index_t Size  = sizeof(ck::half_t);
        };

-       static constexpr index_t NThreadClusterLengths = MPerXdl;
-       static_assert(MPerXdl <= KPerBlock);
+       static constexpr index_t NThreadClusterLengths = 32;
+       static_assert(NPerXdl == 32);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
-           // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }

-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
...
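With NThreadClusterLengths now pinned to 32 (matching the NPerXdl == 32 assertion), the thread block covers the D0 tile as BlockSize / 32 rows of 32 threads, and the remaining assert guarantees that one row of threads, each loading D0BlockTransferSrcScalarPerVector elements, fits inside the N tile. A standalone sketch of that mapping with example values (not the kernel's defaults):

#include <cstdio>

// Illustrative thread-cluster mapping for the D0 copy after the change above.
constexpr int BlockSize             = 256;
constexpr int NThreadClusterLengths = 32;   // pinned to 32, matching NPerXdl == 32
constexpr int NPerBlock             = 128;
constexpr int ScalarPerVector       = 4;    // stands in for D0BlockTransferSrcScalarPerVector

static_assert(ScalarPerVector * NThreadClusterLengths <= NPerBlock,
              "each row of threads must fit inside the N tile with its vector width");

int main()
{
    const int rows = BlockSize / NThreadClusterLengths; // 8 rows of 32 threads
    for(int tid = 0; tid < BlockSize; tid += 97)        // sample a few thread ids
    {
        const int row = tid / NThreadClusterLengths;
        const int col = tid % NThreadClusterLengths;
        std::printf("thread %3d -> cluster coordinate (%d, %d) of (%d, %d)\n",
                    tid, row, col, rows, NThreadClusterLengths);
    }
    return 0;
}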
@@ -1322,15 +1325,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }

-       static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-           GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-       static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-           GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+       static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
+           GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+       static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
+           GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

-       using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+       static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 = d0_block_src_desc_n0_n1_m0_m1_m2;
+       static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 = d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
+
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
...
@@ -1341,34 +1349,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>,               // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type, // SrcData
            typename TypeTransform<D0DataType>::Type, // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,       // SrcDesc
-           decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
+           decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3),   // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,               // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,               // DstDimAccessOrder
            4,                                        // SrcVectorDim
            5,                                        // DstVectorDim
-           D0BlockTransferSrcScalarPerVector,        // SrcScalarPerVector
+           4,                                        // SrcScalarPerVector
            4,                                        // DstScalarPerVector
            1,
            1,
            true,
            true,                                     // DstResetCoord
            1>;

-       using D0ThreadWiseCopy =
+       using D0ThreadwiseCopyLdsToVgpr =
            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
                                             typename TypeTransform<D0DataType>::Type, // DstData
-                                            decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+                                            decltype(d0_block_src_desc_n0_n1_m0_m1_m2),  // SrcDesc
                                             decltype(d0_thread_desc_), // DstDesc
                                             Sequence<1, 1, 4, 1, 4>,   // SliceLengths
                                             Sequence<0, 1, 2, 3, 4>,   // DimAccessOrder
                                             4,                         // SrcVectorDim
-                                            2,                         // SrcScalarPerVector
+                                            4,                         // SrcScalarPerVector
                                             2>;

+       using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+           FloatGemmAcc,
+           typename TypeTransform<D0DataType>::Type,
+           decltype(d0_thread_desc_),
+           decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
+           tensor_operation::element_wise::Scale,    // CElementwiseOperation
+           Sequence<1, 1, 4, 1, 4>,                  // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,                  // AccessOrder
+           4,                                        // VectorDim
+           4,                                        // ScalarPerVector
+           InMemoryDataOperationEnum::Set,           // GlobalMemoryDataOperation
+           1,                                        // DstScalarStrideInVector
+           true>;
+
+       using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+           Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
+           Sequence<0, 1, 2, 3, 5, 4>,               // ThreadClusterArrangeOrder
+           typename TypeTransform<D0DataType>::Type, // SrcData
+           typename TypeTransform<D0DataType>::Type, // DstData
+           decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+           D0GridDescriptor_M0_N0_M1_M2_N1_M3,       // DstDesc
+           Sequence<0, 1, 2, 4, 3, 5>,               // SrcDimAccessOrder
+           Sequence<0, 1, 2, 3, 5, 4>,               // DstDimAccessOrder
+           5,                                        // SrcVectorDim
+           4,                                        // DstVectorDim
+           4,                                        // SrcScalarPerVector
+           D0BlockTransferSrcScalarPerVector,        // DstScalarPerVector
+           1,
+           1,
+           true,
+           true,                                     // DstResetCoord
+           1>;
    };

    struct SharedMemTrait
...
@@ -1412,11 +1463,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            sizeof(GemmDataType) / sizeof(FloatGemmAcc);

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);

        static constexpr auto d0_block_space_offset =
            (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
             q_block_space_size_aligned.value) *
-           sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+           sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...
@@ -1436,7 +1487,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            sizeof(FloatGemmAcc);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
...
@@ -1460,6 +1511,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                               const InputDataType* __restrict__ p_ygrad_grid,
                               OutputDataType* __restrict__ p_qgrad_grid,
                               OutputDataType* __restrict__ p_kgrad_grid,
+                              D0DataType* __restrict__ p_d0grad_grid,
                               OutputDataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
...
@@ -2006,18 +2058,33 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
        // gemm0 M loop
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+       auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
+           D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+       auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
+           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});

        if constexpr(Deterministic)
        {
            block_sync_lds();
...
@@ -2192,49 +2259,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                // add bias
                if constexpr(!is_same<D0DataType, void>::value)
                {
+                   if(p_d0_grid != nullptr)
+                   {
                        static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
                        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                            p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                            static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                           D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                           D0Operator::d0_thread_desc_.GetElementSpaceSize());
                        static_for<0, D0M0, 1>{}([&](auto mr) {
                            // load data to lds
                            d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                                d0_grid_buf);
                            d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                                d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                            d0_block_copy_global_to_lds.RunWrite(
-                               D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                               D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                            block_sync_lds();
                            // read data form lds
-                           d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                           d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_block_buf,
-                                                          D0Loader::d0_thread_desc_,
+                                                          D0Operator::d0_thread_desc_,
                                                           make_tuple(I0, I0, I0, I0, I0),
                                                           d0_thread_buf);
                            // bias add
                            static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                                constexpr index_t c_offset =
                                    c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                                s_slash_p_thread_buf(Number<c_offset>{}) +=
                                    ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                            });
                        });
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3,
                            make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   }
                }
                // P_i: = softmax(scalar * S_i:)
...
@@ -2325,6 +2396,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                             : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(D0Operator::d0_thread_desc_,
+                                                          make_tuple(mr, I0, I0, I0, I0),
+                                                          sgrad_thread_buf,
+                                                          D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+                                                          d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0grad_block_copy_lds_to_global.Run(
+                           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            // gemm dV
            // dV = P_drop^T * dY
            {
...