gaoqiong / composable_kernel, commit 522d8b2f

Authored Sep 12, 2023 by letaoqin

bias grad update to light version

Parent: 9dc3e49b

Showing 10 changed files with 464 additions and 154 deletions (+464 -154)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp (+8 -4)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp (+5 -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp (+27 -4)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp (+27 -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp (+4 -4)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp (+28 -4)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp (+28 -3)
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp (+170 -68)
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp (+167 -64)
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp (+0 -1)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp

@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,

@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
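At the example level the change is purely additive: the two `{}` bias placeholders become four explicit pointer arguments, with `nullptr` meaning "no attention bias, and no bias gradient requested". Below is a minimal stand-alone C++ sketch of this optional-output convention; the function and names are illustrative, not the CK API.

    #include <cstdio>

    // Illustrative stand-in for the kernel's optional-output convention:
    // a null pointer simply disables that input or output.
    void run_attention_backward(const float* p_acc0_bias, float* p_d0grad)
    {
        if(p_acc0_bias != nullptr)
        {
            // would add the bias inside the S = Q * K^T epilogue
        }
        if(p_d0grad != nullptr)
        {
            // would write dBias during the backward pass
        }
        std::printf("bias: %s, bias grad: %s\n",
                    p_acc0_bias ? "on" : "off",
                    p_d0grad ? "on" : "off");
    }

    int main()
    {
        // both features disabled, as in the example's call site above
        run_attention_backward(nullptr, nullptr);
    }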
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp

@@ -24,7 +24,7 @@ Kernel outputs:
 */

 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.

 #include <iostream>
 #include <numeric>

@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},

@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
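The grouped example expresses the same convention with vectors: an empty vector (`{}`) disables the bias-gradient path for every group, while a vector with one pointer per group enables it. A hedged sketch of that caller-side contract follows; the helper is hypothetical, not part of CK.

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Illustrative contract check: a per-group pointer vector must either be
    // empty (feature off for all groups) or hold exactly one entry per group.
    void check_group_vector(std::size_t group_count, const std::vector<void*>& ptrs)
    {
        if(!ptrs.empty() && ptrs.size() != group_count)
            throw std::runtime_error("per-group vector must be empty or group_count long");
    }

    int main()
    {
        std::vector<void*> p_d0grads; // the example passes {}: dBias output disabled
        check_group_vector(4, p_d0grads);
    }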
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp

@@ -123,6 +123,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,

@@ -176,12 +177,20 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-       tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       if(p_d0_grid != nullptr)
+       {
+           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       }
+       if(p_d0grad_grid != nullptr)
+       {
+           tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+       }
    }

    if constexpr(Deterministic)
    {
        for(index_t i = 0; i < nblock; i++)

@@ -197,6 +206,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -233,6 +243,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -266,6 +277,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;

@@ -858,6 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -894,6 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths,
                                                                             a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{

@@ -948,10 +963,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_drop_{p_drop}
        {
            // TODO: implement bias addition
-           ignore = p_acc0_bias;
+           ignore = p_d1grad_grid;
            ignore = p_acc1_bias;
            ignore = acc0_bias_gs_ms_ns_lengths;
            ignore = acc0_bias_gs_ms_ns_strides;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;

@@ -1030,6 +1043,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;

@@ -1191,6 +1205,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            arg.p_ygrad_grid_,
            arg.p_qgrad_grid_,
            arg.p_kgrad_grid_,
+           arg.p_d0grad_grid_,
            arg.p_vgrad_grid_,
            arg.a_element_op_,
            arg.b_element_op_,

@@ -1342,6 +1357,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        OutputDataType* p_vgrad_grid,
        const D0DataType* p_acc0_bias,
        const D1DataType* p_acc1_bias,
+       D0DataType* p_d0grad_grid,
+       D1DataType* p_d1grad_grid,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1380,6 +1397,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,

@@ -1422,6 +1441,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        void* p_vgrad_grid,
        const D0DataType* p_acc0_bias,
        const D1DataType* p_acc1_bias,
+       void* p_d0grad_grid,
+       void* p_d1grad_grid,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1461,6 +1482,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
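Kernel-side, each workgroup resolves its batch's slice of the optional dBias-gradient buffer before the main loop: a compile-time guard skips the code entirely when `D0DataType` is `void`, and a run-time null check keeps the pointer at `nullptr` when the caller requested no bias gradient. A simplified host-side sketch of that two-level guard, in plain C++ without `__builtin_amdgcn_readfirstlane` or the CK descriptor machinery:

    #include <cstdint>
    #include <type_traits>

    // Illustrative analogue of tmp_p_d0grad_grid: resolve a per-batch base
    // pointer only when the feature is compiled in AND the buffer exists.
    template <typename D0DataType>
    D0DataType* resolve_d0grad_slice(D0DataType* p_d0grad_grid, std::int64_t d0_batch_offset)
    {
        D0DataType* tmp = nullptr;
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            if(p_d0grad_grid != nullptr)
                tmp = p_d0grad_grid + d0_batch_offset; // per-batch base pointer
        }
        return tmp;
    }

    int main()
    {
        float buf[16] = {};
        // batch 1 of an illustrative [2, 8] layout
        float* slice = resolve_d0grad_slice(buf, 8);
        (void)slice;
    }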
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp

@@ -123,6 +123,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,

@@ -176,12 +177,20 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-       tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       if(p_d0_grid != nullptr)
+       {
+           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+       }
+       if(p_d0grad_grid != nullptr)
+       {
+           tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+       }
    }

    if constexpr(Deterministic)
    {

@@ -198,6 +207,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -234,6 +244,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -267,6 +278,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;

@@ -874,6 +886,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -910,6 +924,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths,
                                                                             a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{

@@ -964,6 +979,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
+           ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;

@@ -1042,6 +1058,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;

@@ -1207,6 +1224,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            arg.p_ygrad_grid_,
            arg.p_qgrad_grid_,
            arg.p_kgrad_grid_,
+           arg.p_d0grad_grid_,
            arg.p_vgrad_grid_,
            arg.a_element_op_,
            arg.b_element_op_,

@@ -1374,6 +1392,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        OutputDataType* p_vgrad_grid,
        const D0DataType* p_acc0_bias,
        const D1DataType* p_acc1_bias,
+       D0DataType* p_d0grad_grid,
+       D1DataType* p_d1grad_grid,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1412,6 +1432,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,

@@ -1454,6 +1476,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        void* p_vgrad_grid,
        const void* p_acc0_bias,
        const void* p_acc1_bias,
+       void* p_d0grad_grid,
+       void* p_d1grad_grid,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1493,6 +1517,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp

@@ -1333,8 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        void* p_vgrad_grid,
        const void* p_acc0_bias,
        const void* p_acc1_bias,
-       D0DataType* p_d0grad_grid,
-       D1DataType* p_d1grad_grid,
+       void* p_d0grad_grid,
+       void* p_d1grad_grid,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1373,8 +1373,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
-           static_cast<const D0DataType*>(p_d0grad_grid),
-           static_cast<const D1DataType*>(p_d1grad_grid),
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
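The non-light V2 device op receives the matching interface fix: `p_d0grad_grid` and `p_d1grad_grid` become `void*` in `MakeArgumentPointer`, and the casts drop their erroneous `const`, since these are outputs the kernel writes. A tiny stand-alone illustration of why the `const` cast could not have worked:

    int main()
    {
        float dbias = 1.0f;
        float* out  = &dbias; // non-const output pointer, as after the fix
        *out        = 0.0f;   // writing dBias requires mutability
        // const float* bad = &dbias; *bad = 0.0f; // would not compile: read-only
        return 0;
    }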
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp

@@ -162,13 +162,16 @@ __global__ void
                              : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset =
            __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+       if(arg_ptr[group_id].p_d0grad_grid_)
+           tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
    }

    if constexpr(Deterministic)
    {

@@ -185,6 +188,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -222,6 +226,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        const InputDataType* p_ygrad_grid_;
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        OutputDataType* p_vgrad_grid_;

        // tensor descriptors for block/thread-wise copy

@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 std::vector<void*>& p_Vgrads,
                 const std::vector<const void*>& p_acc0_bias_vec,
                 const std::vector<const void*>& p_acc1_bias_vec,
+                const std::vector<void*>& p_d0grads,
+                const std::vector<void*>& p_d1grads,
                 const std::vector<ProblemDesc>& problem_desc_vec,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,

@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                 (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                  ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                0 == p_acc1_bias_vec.size()))
+                0 == p_acc1_bias_vec.size() &&
+                (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                 ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                0 == p_d1grads.size()))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
            }

@@ -937,6 +948,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
                auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+               auto p_d0grad_grid =
+                   (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                       ? static_cast<D0DataType*>(p_d0grads[i])
+                       : nullptr;
                auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);

                const auto& problem_desc = problem_desc_vec[i];

@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                    p_ygrad_grid,
                    p_qgrad_grid,
                    p_kgrad_grid,
+                   p_d0grad_grid,
                    p_vgrad_grid,
                    a_grid_desc_ak0_m_ak1,
                    b_grid_desc_bk0_n_bk1,

@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        std::vector<void*>& p_Vgrads,
        const std::vector<const void*>& p_acc0_bias_vec,
        const std::vector<const void*>& p_acc1_bias_vec,
+       const std::vector<void*>& p_d0grads,
+       const std::vector<void*>& p_d1grads,
        const std::vector<ProblemDesc>& problem_desc_vec,
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,

@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_Vgrads,
                        p_acc0_bias_vec,
                        p_acc1_bias_vec,
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,

@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        std::vector<void*>& p_Vgrads,
        const std::vector<const void*>& p_acc0_bias_vec,
        const std::vector<const void*>& p_acc1_bias_vec,
+       const std::vector<void*>& p_d0grads,
+       const std::vector<void*>& p_d1grads,
        const std::vector<ProblemDesc>& problem_desc_vec,
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,

@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_Vgrads,
                        p_acc0_bias_vec, // cast in struct Argument
                        p_acc1_bias_vec, // cast in struct Argument
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
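The grouped Argument constructor accepts the per-group gradient buffers as `std::vector<void*>` and allows either "one per group" or "none at all"; only in the former case does group `i` get a non-null output pointer. A small stand-alone sketch of that selection rule, with an illustrative `float` element type standing in for `D0DataType`:

    #include <vector>

    // Illustrative version of the per-group pointer selection above: only
    // when the caller supplied one dBias-grad buffer per group does a
    // given group receive a non-null output pointer.
    float* select_d0grad(const std::vector<void*>& p_d0grads, int group_count, int i)
    {
        return (static_cast<int>(p_d0grads.size()) == group_count)
                   ? static_cast<float*>(p_d0grads[i])
                   : nullptr;
    }

    int main()
    {
        std::vector<void*> none; // {} at the call site: dBias off
        float g0[4], g1[4];
        std::vector<void*> per_group{g0, g1};
        (void)select_d0grad(none, 2, 0);      // nullptr
        (void)select_d0grad(per_group, 2, 1); // g1
    }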
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp

@@ -160,13 +160,17 @@ __global__ void
            (arg_ptr[group_id].p_z_grid_ == nullptr
                 ? nullptr
                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

    const D0DataType* tmp_p_d0_grid = nullptr;
+   D0DataType* tmp_p_d0grad_grid   = nullptr;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset =
            __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
        if(arg_ptr[group_id].p_d0_grid_ != nullptr)
            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+       if(arg_ptr[group_id].p_d0grad_grid_)
+           tmp_p_d0grad_grid =
+               arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
    }

    if constexpr(Deterministic)

@@ -184,6 +188,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -221,6 +226,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,

@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        const InputDataType* p_ygrad_grid_;
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        OutputDataType* p_vgrad_grid_;

        // tensor descriptors for block/thread-wise copy

@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 std::vector<void*>& p_Vgrads,
                 const std::vector<const void*>& p_acc0_bias_vec,
                 const std::vector<const void*>& p_acc1_bias_vec,
+                const std::vector<void*>& p_d0grads,
+                const std::vector<void*>& p_d1grads,
                 const std::vector<ProblemDesc>& problem_desc_vec,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,

@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                 (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                  ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                0 == p_acc1_bias_vec.size()))
+                0 == p_acc1_bias_vec.size() &&
+                (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                 ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                0 == p_d1grads.size()))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
            }

@@ -1007,6 +1019,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
                auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+               auto p_d0grad_grid =
+                   (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                       ? static_cast<D0DataType*>(p_d0grads[i])
+                       : nullptr;
                auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);

                const auto& problem_desc = problem_desc_vec[i];

@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                    p_ygrad_grid,
                    p_qgrad_grid,
                    p_kgrad_grid,
+                   p_d0grad_grid,
                    p_vgrad_grid,
                    a_grid_desc_ak0_m_ak1,
                    b_grid_desc_bk0_n_bk1,

@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        std::vector<void*>& p_Vgrads,
        const std::vector<const void*>& p_acc0_bias_vec,
        const std::vector<const void*>& p_acc1_bias_vec,
+       const std::vector<void*>& p_d0grads,
+       const std::vector<void*>& p_d1grads,
        const std::vector<ProblemDesc>& problem_desc_vec,
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,

@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_Vgrads,
                        p_acc0_bias_vec,
                        p_acc1_bias_vec,
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,

@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        std::vector<void*>& p_Vgrads,
        const std::vector<const void*>& p_acc0_bias_vec,
        const std::vector<const void*>& p_acc1_bias_vec,
+       const std::vector<void*>& p_d0grads,
+       const std::vector<void*>& p_d1grads,
        const std::vector<ProblemDesc>& problem_desc_vec,
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,

@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_Vgrads,
                        p_acc0_bias_vec, // cast in struct Argument
                        p_acc1_bias_vec, // cast in struct Argument
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp

@@ -1215,7 +1215,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;

-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform

@@ -1235,13 +1235,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        static_assert(MPerXdl <= KPerBlock);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
            // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }

-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));

@@ -1256,15 +1256,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                           make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }

-       static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-           GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-       static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-           GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+       static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
+           GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+       static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
+           GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();

        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

-       using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,

@@ -1280,7 +1280,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            typename TypeTransform<D0DataType>::Type,         // SrcData
            typename TypeTransform<D0DataType>::Type,         // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // SrcDesc
-           decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3),  // DstDesc
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,                       // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,                       // DstDimAccessOrder
            4,                                                // SrcVectorDim

@@ -1296,13 +1296,56 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        using D0ThreadWiseCopy = ThreadwiseTensorSliceTransfer_v4<
            typename TypeTransform<D0DataType>::Type,    // SrcData
            typename TypeTransform<D0DataType>::Type,    // DstData
-           decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+           decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
            decltype(d0_thread_desc_),                   // DstDesc
            Sequence<1, 1, 4, 1, 4>,                     // SliceLengths
            Sequence<0, 1, 2, 3, 4>,                     // DimAccessOrder
            4,                                           // SrcVectorDim
            2,                                           // SrcScalarPerVector
            2>;
+
+       using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+           FloatGemmAcc,
+           typename TypeTransform<D0DataType>::Type,
+           decltype(d0_thread_desc_),
+           decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
+           tensor_operation::element_wise::Scale, // CElementwiseOperation
+           Sequence<1, 1, 4, 1, 4>,               // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,               // AccessOrder
+           4,                                     // VectorDim
+           4,                                     // ScalarPerVector
+           InMemoryDataOperationEnum::Set,        // GlobalMemoryDataOperation
+           1,                                     // DstScalarStrideInVector
+           true>;
+
+       using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+           Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
+           Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
+           typename TypeTransform<D0DataType>::Type,         // SrcData
+           typename TypeTransform<D0DataType>::Type,         // DstData
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+           D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // DstDesc
+           Sequence<0, 1, 2, 4, 3, 5>,                       // SrcDimAccessOrder
+           Sequence<0, 1, 2, 3, 5, 4>,                       // DstDimAccessOrder
+           5,                                                // SrcVectorDim
+           4,                                                // DstVectorDim
+           4,                                                // SrcScalarPerVector
+           D0BlockTransferSrcScalarPerVector,                // DstScalarPerVector
+           1,
+           1,
+           true,
+           true, // DstResetCoord
+           1>;
    };

    struct SharedMemTrait

@@ -1337,11 +1380,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            q_block_space_size_aligned.value;

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
+           max_lds_align);
        static constexpr auto d0_block_space_offset =
            (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
             q_block_space_size_aligned.value) *
-           sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+           sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

@@ -1358,7 +1402,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

@@ -1381,6 +1425,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
+       D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,

@@ -1848,17 +1893,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        // gemm0 M loop
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadWiseCopy(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));

+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+
+       auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});
+
        if constexpr(Deterministic)
        {
            block_sync_lds();

@@ -1993,6 +2051,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
+               if(p_d0_grid != nullptr)
+               {
                static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();

@@ -2001,10 +2061,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                   D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                   D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                   D0Operator::d0_thread_desc_.GetElementSpaceSize());

                static_for<0, D0M0, 1>{}([&](auto mr) {
                    // load data to lds

@@ -2015,13 +2075,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(0, 0, 1, 0, 0, 0));
                    d0_block_copy_global_to_lds.RunWrite(
-                       D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                    block_sync_lds();
                    // read data form lds
-                   d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                   d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                                                   make_tuple(I0, I0, I0, I0, I0),
                                                   d0_block_buf,
-                                                  D0Loader::d0_thread_desc_,
+                                                  D0Operator::d0_thread_desc_,
                                                   make_tuple(I0, I0, I0, I0, I0),
                                                   d0_thread_buf);

@@ -2036,7 +2097,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                });
                d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                   d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                   make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
            }

            // P_i: = softmax(scalar * S_i:)

@@ -2127,6 +2190,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                    : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(
+                           D0Operator::d0_thread_desc_,
+                           make_tuple(mr, I0, I0, I0, I0),
+                           sgrad_thread_buf,
+                           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+                           d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0_block_copy_lds_to_global.Run(
+                           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                           make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            // gemm dV
            // dV = P_drop^T * dY
            {
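The substance of the commit lives in the two light gridwise kernels: `D0Loader` becomes `D0Operator`, its descriptors are renamed for the address space they describe (global/LDS/VGPR), and a write-back path is added so that the bias gradient dBias = dS can be staged from per-thread registers through LDS to global memory, scaled by `rp_dropout` on the way via `tensor_operation::element_wise::Scale{rp_dropout}`. Below is a minimal CPU model of that staged write-out; the flat indexing, buffer names, and sizes are illustrative, not the CK descriptor machinery.

    #include <cstddef>
    #include <vector>

    // CPU model of the added dBias write-out: per-thread registers ("vgpr")
    // are scaled by the reciprocal dropout keep-probability while being
    // staged into a shared tile ("lds"), then the tile is flushed to the
    // global buffer with a plain Set (no atomics).
    void write_d0grad_tile(const std::vector<float>& sgrad_vgpr,
                           std::vector<float>& lds_tile,
                           std::vector<float>& d0grad_global,
                           std::size_t tile_offset,
                           float rp_dropout)
    {
        for(std::size_t i = 0; i < lds_tile.size(); ++i)
            lds_tile[i] = sgrad_vgpr[i] * rp_dropout; // Scale{rp_dropout} on the way to LDS
        for(std::size_t i = 0; i < lds_tile.size(); ++i)
            d0grad_global[tile_offset + i] = lds_tile[i]; // LDS -> global
    }

    int main()
    {
        std::vector<float> sgrad(8, 1.0f), lds(8), global(16, 0.0f);
        write_d0grad_tile(sgrad, lds, global, 8, 1.25f); // second tile of the row
    }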
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp

@@ -1294,7 +1294,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;

-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform

@@ -1314,13 +1314,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        static_assert(NPerXdl == 32);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
            // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }

-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));

@@ -1335,15 +1335,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                           make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }

-       static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-           GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-       static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-           GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+       static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
+           GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+       static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
+           GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();

        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

-       using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,

@@ -1359,7 +1359,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            typename TypeTransform<D0DataType>::Type,         // SrcData
            typename TypeTransform<D0DataType>::Type,         // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // SrcDesc
-           decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3),  // DstDesc
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,                       // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,                       // DstDimAccessOrder
            4,                                                // SrcVectorDim

@@ -1372,16 +1372,59 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            true, // DstResetCoord
            1>;

-       using D0ThreadWiseCopy = ThreadwiseTensorSliceTransfer_v4<
+       using D0ThreadwiseCopyLdsToVgpr = ThreadwiseTensorSliceTransfer_v4<
            typename TypeTransform<D0DataType>::Type,    // SrcData
            typename TypeTransform<D0DataType>::Type,    // DstData
-           decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+           decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
            decltype(d0_thread_desc_),                   // DstDesc
            Sequence<1, 1, 4, 1, 4>,                     // SliceLengths
            Sequence<0, 1, 2, 3, 4>,                     // DimAccessOrder
            4,                                           // SrcVectorDim
            2,                                           // SrcScalarPerVector
            2>;
+
+       using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+           FloatGemmAcc,
+           typename TypeTransform<D0DataType>::Type,
+           decltype(d0_thread_desc_),
+           decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
+           tensor_operation::element_wise::Scale, // CElementwiseOperation
+           Sequence<1, 1, 4, 1, 4>,               // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,               // AccessOrder
+           4,                                     // VectorDim
+           4,                                     // ScalarPerVector
+           InMemoryDataOperationEnum::Set,        // GlobalMemoryDataOperation
+           1,                                     // DstScalarStrideInVector
+           true>;
+
+       using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+           Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
+           Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
+           typename TypeTransform<D0DataType>::Type,         // SrcData
+           typename TypeTransform<D0DataType>::Type,         // DstData
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+           D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // DstDesc
+           Sequence<0, 1, 2, 4, 3, 5>,                       // SrcDimAccessOrder
+           Sequence<0, 1, 2, 3, 5, 4>,                       // DstDimAccessOrder
+           5,                                                // SrcVectorDim
+           4,                                                // DstVectorDim
+           4,                                                // SrcScalarPerVector
+           D0BlockTransferSrcScalarPerVector,                // DstScalarPerVector
+           1,
+           1,
+           true,
+           true, // DstResetCoord
+           1>;
    };

    struct SharedMemTrait

@@ -1416,10 +1459,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
+           max_lds_align);
        static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
-           sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+           sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

@@ -1444,7 +1488,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

@@ -1472,6 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
+       D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,

@@ -1969,17 +2014,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));

+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+
+       auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});
+
        if constexpr(Deterministic)
        {
            block_sync_lds();

@@ -2144,6 +2202,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
+               if(p_d0_grid != nullptr)
+               {
                static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();

@@ -2152,10 +2212,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                   D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                   D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                   D0Operator::d0_thread_desc_.GetElementSpaceSize());
                ignore = d0_thread_buf;

                static_for<0, D0M0, 1>{}([&](auto mr) {

@@ -2167,13 +2227,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(0, 0, 1, 0, 0, 0));
                    d0_block_copy_global_to_lds.RunWrite(
-                       D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                    block_sync_lds();
                    // read data form lds
-                   d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                   d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                                                   make_tuple(I0, I0, I0, I0, I0),
                                                   d0_block_buf,
-                                                  D0Loader::d0_thread_desc_,
+                                                  D0Operator::d0_thread_desc_,
                                                   make_tuple(I0, I0, I0, I0, I0),
                                                   d0_thread_buf);

@@ -2188,7 +2249,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                });
                d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                   d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                   make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
            }

            // P_i: = softmax(scalar * S_i:)

@@ -2395,6 +2458,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                    : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(
+                           D0Operator::d0_thread_desc_,
+                           make_tuple(mr, I0, I0, I0, I0),
+                           sgrad_thread_buf,
+                           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+                           d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0_block_copy_lds_to_global.Run(
+                           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                           make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
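Both light variants walk the dBias grid with the same window bookkeeping: inside the `static_for` over `D0M0` the destination window advances one M0 step per tile, and the trailing `MoveDstSliceWindow(..., make_multi_index(-1, 0, -D0M0.value, 0, 0, 0))` appears to rewind the inner M0 steps while stepping the outer M-block index back by one, matching the bottom-to-top outer loop. A plain-C++ sketch of that index arithmetic follows; the dimensions are made up, and the reading of the negative step is an inference from the surrounding loop structure.

    #include <cstdio>

    int main()
    {
        const int D0M0 = 4; // illustrative inner window count per outer block
        int outer = 2;      // gemm0_m_block_outer_index walks backward
        int m0    = 0;      // inner window position
        for(; outer >= 0; --outer)
        {
            for(int mr = 0; mr < D0M0; ++mr)
            {
                std::printf("write tile (outer=%d, m0=%d)\n", outer, m0);
                ++m0; // MoveDstSliceWindow(0, 0, +1, 0, 0, 0)
            }
            m0 -= D0M0; // MoveDstSliceWindow(-1, 0, -D0M0, 0, 0, 0): rewind m0,
                        // while the -1 in the first slot steps the outer block back
        }
    }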
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp

@@ -2151,7 +2151,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

-       ignore = d0grad_thread_copy_vgpr_to_lds;
        if constexpr(Deterministic)
        {
            block_sync_lds();