gaoqiong / composable_kernel

Commit 592b0649
authored Sep 11, 2023 by letaoqin

grouped bwd add bias grad

parent 0dba17c3
Showing 4 changed files with 95 additions and 15 deletions.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp  (+1, -1)
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp  (+38, -7)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp  (+28, -3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp  (+28, -4)
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

@@ -603,7 +603,7 @@ int run(int argc, char* argv[])
         (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
          sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
          sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
+         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
             BatchCount +
         sizeof(LSEDataType) * M * BatchCount;
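The only change in this file is the traffic estimate: since the example now also writes the attention-bias gradient, the Acc0Bias term is counted twice (one read of the bias, one write of its gradient). A minimal, stand-alone sketch of that accounting, assuming a 2-byte Acc0BiasDataType and made-up sizes; none of these numbers come from the commit:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Illustrative sizes only; the real example derives M, N and BatchCount from the problem setup.
        const std::size_t M = 1024, N = 1024, BatchCount = 16;
        const std::size_t bias_bytes = 2; // e.g. a half-precision Acc0BiasDataType

        const std::size_t old_bias_term = bias_bytes * M * N;     // bias read only
        const std::size_t new_bias_term = bias_bytes * M * N * 2; // bias read + bias-gradient write
        std::printf("bias traffic per batch: %zu -> %zu bytes\n", old_bias_term, new_bias_term);
        std::printf("extra traffic per launch: %zu bytes\n",
                    (new_bias_term - old_bias_term) * BatchCount);
        return 0;
    }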
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp

@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -333,6 +333,7 @@ int run(int argc, char* argv[])
 std::vector<const void*> p_lse;
 std::vector<void*> p_qgrad;
 std::vector<void*> p_kgrad;
+std::vector<void*> p_d0grad;
 std::vector<void*> p_vgrad;
 std::vector<const void*> p_ygrad;

@@ -356,6 +357,7 @@ int run(int argc, char* argv[])
 std::vector<Tensor<LSEDataType>> lse_tensors;
 std::vector<Tensor<OutputDataType>> qgrad_tensors;
 std::vector<Tensor<OutputDataType>> kgrad_tensors;
+std::vector<Tensor<Acc0BiasDataType>> d0grad_tensors;
 std::vector<Tensor<OutputDataType>> vgrad_tensors;
 std::vector<Tensor<InputDataType>> ygrad_tensors;

@@ -369,6 +371,7 @@ int run(int argc, char* argv[])
 std::vector<DeviceMemPtr> qgrad_tensors_device;
 std::vector<DeviceMemPtr> ygrad_tensors_device;
 std::vector<DeviceMemPtr> kgrad_tensors_device;
+std::vector<DeviceMemPtr> d0grad_tensors_device;
 std::vector<DeviceMemPtr> vgrad_tensors_device;
 std::size_t group_count = 10;
 std::size_t flop = 0, num_byte = 0;
@@ -445,12 +448,13 @@ int run(int argc, char* argv[])
 int BatchCount = G0 * G1;
 flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
 // Q/K/V/Y, dQ/dK/dV/dY, LSE
-num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
-             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
-             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-             sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
-                BatchCount +
-            sizeof(LSEDataType) * M * BatchCount;
+num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+             sizeof(OutputDataType) * N * O +
+             sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
+                BatchCount +
+            sizeof(LSEDataType) * M * BatchCount;
 Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
 Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
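For reference, the flop estimate above counts each multiply-add as two operations over three M x N x K GEMM-shaped products and two M x N x O ones per batch; the exact mapping onto the backward GEMMs (dQ, dK, dV, dS, dP) is my reading of the formula, not something stated in the commit. A stand-alone sketch of the same arithmetic with illustrative sizes:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Illustrative sizes; the example takes these from the per-group problem descriptors.
        const std::size_t M = 512, N = 512, K = 64, O = 64, BatchCount = 16;

        // Same formula as the example: (3*M*N*K + 2*M*N*O) multiply-adds, 2 ops each, per batch.
        const std::size_t flop =
            (std::size_t(3) * M * N * K + std::size_t(2) * M * N * O) * 2 * BatchCount;
        std::printf("estimated flops: %zu\n", flop);
        return 0;
    }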
@@ -600,6 +604,8 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
         kgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        d0grad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
         vgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
         ygrad_tensors_device.emplace_back(

@@ -619,6 +625,7 @@ int run(int argc, char* argv[])
 p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
 p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
 p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
+p_d0grad.push_back(d0grad_tensors_device.back()->GetDeviceBuffer());
 p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
 p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
 p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());

@@ -636,6 +643,8 @@ int run(int argc, char* argv[])
 p_vgrad,
 p_d0,
 {},
+p_d0grad,
+{},
 problem_descs,
 QKVElementOp{},
 QKVElementOp{},

@@ -682,6 +691,8 @@ int run(int argc, char* argv[])
 p_vgrad,
 p_d0,
 {},
+p_d0grad,
+{},
 problem_descs,
 QKVElementOp{},
 QKVElementOp{},

@@ -732,6 +743,7 @@ int run(int argc, char* argv[])
 lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
 qgrad_tensors_device[i]->SetZero();
 kgrad_tensors_device[i]->SetZero();
+d0grad_tensors_device[i]->SetZero();
 vgrad_tensors_device[i]->SetZero();
 }
@@ -804,6 +816,8 @@ int run(int argc, char* argv[])
                                                   q_tensors[i].GetStrides());
 Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
                                                   k_tensors[i].GetStrides());
+Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
+                                                     d0_tensors[i].GetStrides());
 Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
                                                   v_tensors[i].GetStrides());

@@ -811,11 +825,14 @@ int run(int argc, char* argv[])
                                                     q_tensors[i].GetStrides());
 Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
                                                     k_tensors[i].GetStrides());
+Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
+                                                       d0_tensors[i].GetStrides());
 Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
                                                     v_tensors[i].GetStrides());
 qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
 kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+d0grad_tensors_device[i]->FromDevice(d0grad_gs_ms_ns_device_result.data());
 vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
 // permute
 qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
@@ -834,6 +851,14 @@ int run(int argc, char* argv[])
     self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
 });
+d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
+    const size_t& g0 = idx[0];
+    const size_t& g1 = idx[1];
+    const size_t g   = g0 * G1 + g1;
+    self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
+});
 vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
     const size_t& g0 = idx[0];
     const size_t& g1 = idx[1];
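The host reference fills the bias-gradient tensor directly from sgrad_g_m_n, the gradient of the pre-softmax scores. That is the expected identity when the bias is added elementwise to the scaled QK^T scores before the softmax; a brief sketch of the reasoning, in notation of my own (alpha for the softmax scale), not taken from the commit:

    S = \alpha \, Q K^{\top} + D_0
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial (D_0)_{g,m,n}} = \frac{\partial L}{\partial S_{g,m,n}},
    \qquad g = g_0 \cdot G_1 + g_1

so the ForEach above only un-flattens the batch index; no extra arithmetic is needed for the bias gradient itself.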
@@ -861,6 +886,12 @@ int run(int argc, char* argv[])
                              "error",
                              1e-2,
                              1e-2);
+std::cout << "Checking d0grad:\n";
+pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
+                             d0grad_gs_ms_ns_host_result.mData,
+                             "error",
+                             1e-2,
+                             1e-2);
     }
 }
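The new check mirrors the existing dQ/dK/dV checks, with 1e-2 passed twice. A generic, stand-alone sketch of a mixed relative/absolute comparison of two buffers, assuming (not verified against CK's check_err) that the two thresholds play the roles of relative and absolute tolerance:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Returns true when every element pair is within atol + rtol * |ref|.
    bool all_close(const std::vector<float>& out, const std::vector<float>& ref,
                   float rtol, float atol)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
            if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
                return false;
        return true;
    }

    int main()
    {
        std::vector<float> device_result{0.999f, 2.004f};
        std::vector<float> host_result{1.0f, 2.0f};
        std::printf("pass: %d\n", all_close(device_result, host_result, 1e-2f, 1e-2f));
        return 0;
    }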
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp

@@ -103,13 +103,17 @@ __global__ void
                       : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset =
             __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0grad_grid_)
+            tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
     }
     if constexpr(Deterministic)
     {
@@ -126,6 +130,7 @@ __global__ void
 arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
 arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
 arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+tmp_p_d0grad_grid,
 arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
 p_shared,
 a_element_op,

@@ -164,6 +169,7 @@ __global__ void
 arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
 arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
 arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+tmp_p_d0grad_grid,
 arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
 p_shared,
 a_element_op,

@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 const InputDataType* p_ygrad_grid_;
 OutputDataType* p_qgrad_grid_;
 OutputDataType* p_kgrad_grid_;
+D0DataType* p_d0grad_grid_;
 OutputDataType* p_vgrad_grid_;
 // tensor descriptors for block/thread-wise copy

@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
      group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
      (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
       ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-     0 == p_acc1_bias_vec.size()))
+     0 == p_acc1_bias_vec.size() &&
+     (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+      ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+     0 == p_d1grads.size()))
 {
     throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
 }
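The extended check encodes the calling contract for the new arguments: p_d0grads must either be empty (bias gradient not requested) or supply exactly one pointer per group, and p_d1grads must stay empty. A stand-alone sketch of that contract with plain std::vector sizes; the helper name is mine, not part of the device op:

    #include <cstddef>
    #include <cstdio>

    // Mirrors the shape of the new argument check: an optional grouped output is valid
    // when it is either absent (size 0) or provides exactly one entry per group.
    bool optional_grouped_output_ok(std::size_t group_count, std::size_t d0grads_size,
                                    std::size_t d1grads_size)
    {
        return (d0grads_size == group_count || d0grads_size == 0) && d1grads_size == 0;
    }

    int main()
    {
        std::printf("%d\n", optional_grouped_output_ok(10, 10, 0)); // bias grad for every group
        std::printf("%d\n", optional_grouped_output_ok(10, 0, 0));  // bias grad not requested
        std::printf("%d\n", optional_grouped_output_ok(10, 3, 0));  // rejected: partial list
        return 0;
    }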
@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
 auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
 auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+auto p_d0grad_grid      = (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                              ? static_cast<D0DataType*>(p_d0grads[i])
+                              : nullptr;
 auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);
 const auto& problem_desc = problem_desc_vec[i];
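Per group, the bias-gradient pointer is only materialized when the caller supplied a full list; otherwise the kernel receives nullptr and the guarded offset code above leaves the output untouched. A small, runnable sketch of that selection pattern, with float standing in for the real D0DataType:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Per-group selection: hand out the i-th destination only when the caller provided
    // one destination per group; otherwise signal "no output requested" with nullptr.
    float* select_optional_output(const std::vector<void*>& d0grads, std::size_t group_count,
                                  std::size_t i)
    {
        return (d0grads.size() == group_count) ? static_cast<float*>(d0grads[i]) : nullptr;
    }

    int main()
    {
        float buf0[4]{}, buf1[4]{};
        std::vector<void*> requested{buf0, buf1};
        std::vector<void*> not_requested; // empty: bias gradient disabled

        std::printf("group 1 dst: %p\n", (void*)select_optional_output(requested, 2, 1));
        std::printf("group 1 dst: %p\n", (void*)select_optional_output(not_requested, 2, 1));
        return 0;
    }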
@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 p_ygrad_grid,
 p_qgrad_grid,
 p_kgrad_grid,
+p_d0grad_grid,
 p_vgrad_grid,
 a_grid_desc_ak0_m_ak1,
 b_grid_desc_bk0_n_bk1,

@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,

@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 p_Vgrads,
 p_acc0_bias_vec,
 p_acc1_bias_vec,
+p_d0grads,
+p_d1grads,
 problem_desc_vec,
 a_element_op,
 b_element_op,

@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,

@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 p_Vgrads,
 p_acc0_bias_vec, // cast in struct Argument
 p_acc1_bias_vec, // cast in struct Argument
+p_d0grads,
+p_d1grads,
 problem_desc_vec,
 a_element_op,
 b_element_op,
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp

@@ -102,13 +102,16 @@ __global__ void
                   (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                       : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset =
             __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        if(arg_ptr[group_id].p_d0grad_grid_)
+            tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
     }
     if constexpr(Deterministic)

@@ -126,6 +129,7 @@ __global__ void
 arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
 arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
 arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+tmp_p_d0grad_grid,
 arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
 p_shared,
 a_element_op,

@@ -164,6 +168,7 @@ __global__ void
 arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
 arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
 arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+tmp_p_d0grad_grid,
 arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
 p_shared,
 a_element_op,

@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 const InputDataType* p_ygrad_grid_;
 OutputDataType* p_qgrad_grid_;
 OutputDataType* p_kgrad_grid_;
+D0DataType* p_d0grad_grid_;
 OutputDataType* p_vgrad_grid_;
 // tensor descriptors for block/thread-wise copy

@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,

@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
      group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
      (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
       ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-     0 == p_acc1_bias_vec.size()))
+     0 == p_acc1_bias_vec.size() &&
+     (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+      ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+     0 == p_d1grads.size()))
 {
     throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
 }

@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
 auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
 auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+auto p_d0grad_grid      = (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                              ? static_cast<D0DataType*>(p_d0grads[i])
+                              : nullptr;
 auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);
 const auto& problem_desc = problem_desc_vec[i];

@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 p_ygrad_grid,
 p_qgrad_grid,
 p_kgrad_grid,
+p_d0grad_grid,
 p_vgrad_grid,
 a_grid_desc_ak0_m_ak1,
 b_grid_desc_bk0_n_bk1,

@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,

@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 p_Vgrads,
 p_acc0_bias_vec,
 p_acc1_bias_vec,
+p_d0grads,
+p_d1grads,
 problem_desc_vec,
 a_element_op,
 b_element_op,

@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 std::vector<void*>& p_Vgrads,
 const std::vector<const void*>& p_acc0_bias_vec,
 const std::vector<const void*>& p_acc1_bias_vec,
+const std::vector<void*>& p_d0grads,
+const std::vector<void*>& p_d1grads,
 const std::vector<ProblemDesc>& problem_desc_vec,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,

@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 p_Vgrads,
 p_acc0_bias_vec, // cast in struct Argument
 p_acc1_bias_vec, // cast in struct Argument
+p_d0grads,
+p_d1grads,
 problem_desc_vec,
 a_element_op,
 b_element_op,