gaoqiong / composable_kernel
Commit fa94a220
Authored Sep 08, 2023 by letaoqin

    device add d0 grad

Parent: b7b7e153
Showing 3 changed files with 39 additions and 6 deletions:

    example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp                  +6 -1
    include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp  +28 -2
    include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp   +5 -3
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64  // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -494,6 +494,7 @@ int run(int argc, char* argv[])
     DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem dgrad_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());

     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
@@ -518,6 +519,8 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
         static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                        // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -563,6 +566,8 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
         static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                        // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
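The example change above is the host-side half of the commit: a new dgrad_device_buf is allocated for the d0 (attention-bias) gradient and passed to the device op right after the two bias pointers, with nullptr still passed for the unimplemented d1 gradient. Below is a minimal, hedged sketch of how that gradient could be copied back for inspection after the invoker has run; it assumes CK's host utilities (Tensor, DeviceMem::FromDevice), and the host tensor name d0grad_gs_ms_ns is introduced here for illustration only, it is not part of the commit.

    // Sketch only: copy the d0 gradient back to a host tensor shaped like the
    // bias tensor d_gs_ms_ns, after invoker.Run(...) has completed.
    Tensor<Acc0BiasDataType> d0grad_gs_ms_ns(d_gs_ms_ns.mDesc); // hypothetical host buffer
    dgrad_device_buf.FromDevice(d0grad_gs_ms_ns.mData.data());  // device -> host copy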
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
@@ -65,6 +65,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
@@ -120,13 +121,21 @@ __global__ void
         const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

         const D0DataType* tmp_p_d0_grid = nullptr;
+        D0DataType* tmp_p_d0grad_grid   = nullptr;
         if constexpr(!is_same<D0DataType, void>::value)
         {
             const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
                 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-            tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+            if(p_d0_grid != nullptr)
+            {
+                tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+            }
+            if(p_d0grad_grid != nullptr)
+            {
+                tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+            }
         }

         if constexpr(Deterministic)
         {
             for(index_t i = 0; i < nblock; i++)
@@ -142,6 +151,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -179,6 +189,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -213,6 +224,7 @@ __global__ void
     ignore = p_ygrad_grid;
     ignore = p_qgrad_grid;
     ignore = p_kgrad_grid;
+    ignore = p_d0grad_grid;
     ignore = p_vgrad_grid;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -771,6 +783,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                  OutputDataType* p_vgrad_grid,
                  const D0DataType* p_acc0_bias,
                  const D1DataType* p_acc1_bias,
+                 D0DataType* p_d0grad_grid,
+                 D1DataType* p_d1grad_grid,
                  const std::vector<index_t>& a_gs_ms_ks_lengths,
                  const std::vector<index_t>& a_gs_ms_ks_strides,
                  const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -806,6 +820,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
               p_qgrad_grid_{p_qgrad_grid},
               p_kgrad_grid_{p_kgrad_grid},
               p_vgrad_grid_{p_vgrad_grid},
+              p_d0grad_grid_{p_d0grad_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
               b_grid_desc_bk0_n_bk1_{
@@ -855,6 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         {
             // TODO: implement bias addition
             ignore = p_acc1_bias;
+            ignore = p_d1grad_grid;
             ignore = acc1_bias_gs_ms_gemm1ns_lengths;
             ignore = acc1_bias_gs_ms_gemm1ns_strides;
@@ -939,6 +955,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
         OutputDataType* p_vgrad_grid_;
+        D0DataType* p_d0grad_grid_;

         // tensor descriptor
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -1066,6 +1083,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 arg.p_ygrad_grid_,
                 arg.p_qgrad_grid_,
                 arg.p_kgrad_grid_,
+                arg.p_d0grad_grid_,
                 arg.p_vgrad_grid_,
                 arg.a_element_op_,
                 arg.b_element_op_,
@@ -1233,6 +1251,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                              OutputDataType* p_vgrad_grid,
                              const D0DataType* p_acc0_bias,
                              const D1DataType* p_acc1_bias,
+                             D0DataType* p_d0grad_grid,
+                             D1DataType* p_d1grad_grid,
                              const std::vector<index_t>& a_gs_ms_ks_lengths,
                              const std::vector<index_t>& a_gs_ms_ks_strides,
                              const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1270,6 +1290,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         p_vgrad_grid,
                         p_acc0_bias,
                         p_acc1_bias,
+                        p_d0grad_grid,
+                        p_d1grad_grid,
                         a_gs_ms_ks_lengths,
                         a_gs_ms_ks_strides,
                         b_gs_ns_ks_lengths,
@@ -1311,6 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         void* p_vgrad_grid,
                         const void* p_acc0_bias,
                         const void* p_acc1_bias,
+                        D0DataType* p_d0grad_grid,
+                        D1DataType* p_d1grad_grid,
                         const std::vector<index_t>& a_gs_ms_ks_lengths,
                         const std::vector<index_t>& a_gs_ms_ks_strides,
                         const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1349,6 +1373,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             static_cast<OutputDataType*>(p_vgrad_grid),
             static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
             static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+            static_cast<const D0DataType*>(p_d0grad_grid),
+            static_cast<const D1DataType*>(p_d1grad_grid),
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b_gs_ns_ks_lengths,
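On the device-op side the change is pure plumbing: the kernel signature, the Argument struct, MakeArgument, MakeArgumentPointer, and the kernel launch all gain the d0grad/d1grad pointers, and the per-batch setup now applies the batch offset only when the corresponding grid pointer was actually supplied. Below is a self-contained sketch of that guard pattern; it is illustrative only (the function name is made up, not CK's) and assumes nothing beyond the standard library.

    #include <cstddef>
    #include <type_traits>

    // Stand-alone version of the guard used in the kernel hunk above: when
    // D0DataType is void the whole branch is compiled out; otherwise the batch
    // offset is applied only to a pointer that was actually provided.
    template <typename D0DataType>
    D0DataType* offset_d0grad_for_batch(D0DataType* p_d0grad_grid, std::ptrdiff_t d0_batch_offset)
    {
        D0DataType* tmp_p_d0grad_grid = nullptr;
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            if(p_d0grad_grid != nullptr)
            {
                tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
            }
        }
        return tmp_p_d0grad_grid;
    }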
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
@@ -1478,9 +1478,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
    static constexpr auto d0_block_space_offset =
        k_block_space_size_aligned.value * sizeof(GemmDataType) /
        D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1537,6 +1537,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                const InputDataType* __restrict__ p_ygrad_grid,
                                OutputDataType* __restrict__ p_qgrad_grid,
                                OutputDataType* __restrict__ p_kgrad_grid,
+                               D0DataType* __restrict__ p_d0grad_grid,
                                OutputDataType* __restrict__ p_vgrad_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
@@ -1562,6 +1563,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                const index_t raw_n_padded,
                                const index_t block_idx_n)
    {
+        ignore = p_d0grad_grid;
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
        const ushort p_dropout_in_16bits =
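Inside the gridwise Run body the new p_d0grad_grid parameter is only assigned to ignore for now, i.e. this commit threads the pointer through the launch path without consuming it in the hunks shown here. CK's ignore is its own std::ignore-style sink; a minimal illustration of the idiom follows (the function name is illustrative, not CK's).

    #include <tuple>

    // Assigning an unused parameter to std::ignore (ck::ignore works the same
    // way) suppresses unused-parameter warnings without generating code.
    void accept_but_do_not_use(int* p_d0grad_grid) { std::ignore = p_d0grad_grid; }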