gaoqiong / composable_kernel / Commits

Commit 0dba17c3, authored Sep 11, 2023 by letaoqin

    output d0 grad form bwd prototype 1

Parent: 35379cdb
Showing 3 changed files with 157 additions and 37 deletions.

  +2   -2    example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
  +27  -4    include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
  +128 -31   include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>

@@ -284,7 +284,7 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop                    = 0.0;
+    float p_drop                    = 0.9;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
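The example change turns dropout on (p_drop goes from 0.0 to 0.9) and halves the head dimension, so the new dD0 output path below is exercised under dropout. Later in this commit the dS-to-dD0 store is fused with tensor_operation::element_wise::Scale{rp_dropout}; in this kernel family rp_dropout is the usual inverted-dropout rescale 1/(1 - p_drop). A minimal host-side sketch of that relationship, assuming that convention (make_rp_dropout is a hypothetical helper, not a CK function):

    #include <cassert>
    #include <cstdio>

    // Hypothetical helper: inverted-dropout rescale factor. Kept elements are
    // multiplied by 1/(1 - p_drop) so the expected value of the masked tensor
    // matches the unmasked one.
    static float make_rp_dropout(float p_drop)
    {
        assert(p_drop >= 0.0f && p_drop < 1.0f);
        return 1.0f / (1.0f - p_drop);
    }

    int main()
    {
        const float p_drop = 0.9f; // the value this commit sets in the example
        std::printf("rp_dropout = %g\n", make_rp_dropout(p_drop)); // prints 10
        return 0;
    }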
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp

@@ -65,6 +65,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
@@ -120,12 +121,20 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid   = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
+        if(p_d0_grid != nullptr)
+        {
             tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+        }
+        if(p_d0grad_grid != nullptr)
+        {
+            tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+        }
     }

     if constexpr(Deterministic)
     {
         for(index_t i = 0; i < nblock; i++)
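The hunk above shows the pattern the kernel uses for its new optional output: the dD0-gradient pointer is compiled out entirely when D0DataType is void (via if constexpr), and skipped at run time when the caller passes nullptr. A self-contained sketch of the same two-level guard, with hypothetical names (process is ours, only the idiom comes from the diff):

    #include <cstdio>
    #include <type_traits>

    // Two-level optional output: the branch disappears at compile time when
    // the data type is void, and is skipped at run time when no buffer exists.
    template <typename D0DataType>
    void process(float grad, D0DataType* p_d0grad)
    {
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            if(p_d0grad != nullptr)
            {
                *p_d0grad = static_cast<D0DataType>(grad); // write optional output
            }
        }
        else
        {
            (void)p_d0grad; // D0 support compiled out; pointer is unused
        }
    }

    int main()
    {
        float d0grad = 0.0f;
        process<float>(2.5f, &d0grad);        // writes the optional output
        process<float>(1.0f, nullptr);        // caller opted out at run time
        process<void>(1.0f, nullptr);         // D0 disabled at compile time
        std::printf("d0grad = %g\n", d0grad); // prints 2.5
        return 0;
    }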
@@ -141,6 +150,7 @@ __global__ void
             p_ygrad_grid + c_batch_offset,
             p_qgrad_grid + a_batch_offset,
             p_kgrad_grid + b_batch_offset,
+            tmp_p_d0grad_grid,
             p_vgrad_grid + b1_batch_offset,
             p_shared,
             a_element_op,

@@ -178,6 +188,7 @@ __global__ void
             p_ygrad_grid + c_batch_offset,
             p_qgrad_grid + a_batch_offset,
             p_kgrad_grid + b_batch_offset,
+            tmp_p_d0grad_grid,
             p_vgrad_grid + b1_batch_offset,
             p_shared,
             a_element_op,

@@ -212,6 +223,7 @@ __global__ void
         ignore = p_ygrad_grid;
         ignore = p_qgrad_grid;
         ignore = p_kgrad_grid;
+        ignore = p_d0grad_grid;
         ignore = p_vgrad_grid;
         ignore = a_element_op;
         ignore = b_element_op;
@@ -755,6 +767,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -790,6 +804,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
               p_qgrad_grid_{p_qgrad_grid},
               p_kgrad_grid_{p_kgrad_grid},
               p_vgrad_grid_{p_vgrad_grid},
+              p_d0grad_grid_{p_d0grad_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
               b_grid_desc_bk0_n_bk1_{

@@ -839,10 +854,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
               p_drop_{p_drop}
         {
             // TODO: implement bias addition
-            ignore = p_acc0_bias;
+            ignore = p_d1grad_grid;
             ignore = p_acc1_bias;
-            ignore = acc0_bias_gs_ms_ns_lengths;
-            ignore = acc0_bias_gs_ms_ns_strides;
             ignore = acc1_bias_gs_ms_gemm1ns_lengths;
             ignore = acc1_bias_gs_ms_gemm1ns_strides;
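The constructor still discards the second-bias arguments with `ignore = ...;` (CK's std::ignore-style sink for silencing unused-parameter warnings; the bias addition itself remains a TODO), and the newly added but not-yet-used p_d1grad_grid joins that list. A sketch of what such an ignore helper can look like, as an illustration of the idiom rather than CK's exact definition:

    #include <vector>

    // A std::ignore-style sink: assigning anything to it is a no-op, which
    // suppresses unused-parameter warnings without macros or casts.
    struct ignore_t
    {
        template <typename T>
        constexpr const ignore_t& operator=(const T&) const
        {
            return *this;
        }
    };
    inline constexpr ignore_t ignore{};

    void make_argument_stub(const float* p_acc1_bias, std::vector<int> acc1_bias_lengths)
    {
        // TODO: implement bias addition (mirrors the commit's placeholder)
        ignore = p_acc1_bias;
        ignore = acc1_bias_lengths;
    }

    int main() { make_argument_stub(nullptr, {}); return 0; }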
@@ -926,6 +939,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
         OutputDataType* p_vgrad_grid_;
+        D0DataType* p_d0grad_grid_;

         // tensor descriptor
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;

@@ -1049,6 +1063,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 arg.p_ygrad_grid_,
                 arg.p_qgrad_grid_,
                 arg.p_kgrad_grid_,
+                arg.p_d0grad_grid_,
                 arg.p_vgrad_grid_,
                 arg.a_element_op_,
                 arg.b_element_op_,
@@ -1200,6 +1215,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1237,6 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             p_vgrad_grid,
             p_acc0_bias,
             p_acc1_bias,
+            p_d0grad_grid,
+            p_d1grad_grid,
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b_gs_ns_ks_lengths,

@@ -1278,6 +1297,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 void* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,

@@ -1316,6 +1337,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             static_cast<OutputDataType*>(p_vgrad_grid),
             static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
             static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+            static_cast<const D0DataType*>(p_d0grad_grid),
+            static_cast<const D1DataType*>(p_d1grad_grid),
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b_gs_ns_ks_lengths,
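Taken together, the device-level changes thread two new output pointers (dD0, plus a dD1 that the constructor immediately re-ignores) through the Argument struct, MakeArgument, and the type-erased MakeArgumentPointer, in each case placed right after the two bias pointers. A hedged sketch of where they slot in; the struct, function, and variable names below are stand-ins for illustration, and only the parameter ordering and the nullptr opt-out come from this diff:

    #include <cstdio>

    // Simplified stand-in for the device op's argument bundle, showing where
    // the two new pointers sit relative to the existing bias pointers.
    struct ArgumentSketch
    {
        float*       p_vgrad_grid;
        const float* p_acc0_bias;   // existing: forward attention bias D0
        const float* p_acc1_bias;   // existing: second bias D1
        float*       p_d0grad_grid; // new in this commit: gradient of D0
        float*       p_d1grad_grid; // new in this commit: reserved, still ignored
    };

    ArgumentSketch make_argument_sketch(float* p_vgrad, const float* p_d0, float* p_d0grad)
    {
        // Passing nullptr for p_d0grad_grid skips the dD0 write-out at run
        // time, matching the kernel-side nullptr guard added by this commit.
        return ArgumentSketch{p_vgrad, p_d0, nullptr, p_d0grad, nullptr};
    }

    int main()
    {
        float vgrad[4] = {}, d0grad[4] = {};
        const float d0[4] = {};
        ArgumentSketch arg = make_argument_sketch(vgrad, d0, d0grad);
        std::printf("dD0 requested: %s\n", arg.p_d0grad_grid ? "yes" : "no");
        return 0;
    }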
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp

@@ -1286,7 +1286,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
         remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;

-    struct D0Loader
+    struct D0Operator
     {
         template <typename DataType>
         struct TypeTransform
@@ -1306,13 +1306,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         static_assert(MPerXdl <= KPerBlock);
         static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                       "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");

-        __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+        __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
         {
             // B1 matrix in LDS memory, dst of blockwise copy
             return make_naive_tensor_descriptor_packed(
                 make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
         }

-        __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+        __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
         {
             constexpr auto d0_raw_m0_n_m1 =
                 make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
@@ -1327,10 +1327,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));

             return d0_n0_n1_m0_m1_m2;
         }

-        static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-            GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-        static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-            GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+        static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
+            GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+        static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
+            GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();

         static constexpr auto d0_thread_desc_ =
             make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
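Both descriptors are built with make_naive_tensor_descriptor_packed, i.e. a dense layout whose element space is just the product of the lengths: the LDS tile is (1, 1, 1, D0M1, NPerBlock, D0M2) and each thread holds a (1, 1, 4, 1, D0M2) slice. A small sketch of that size computation under the packed-layout assumption (the helper name and the sample tile parameters are ours, not CK's):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // For a packed (dense) descriptor the element space size is the product of
    // the dimension lengths; this mirrors GetElementSpaceSize() on a
    // make_naive_tensor_descriptor_packed result under the packed assumption.
    template <std::size_t N>
    constexpr std::size_t packed_element_space_size(const std::array<std::size_t, N>& lengths)
    {
        std::size_t size = 1;
        for(std::size_t len : lengths)
            size *= len;
        return size;
    }

    int main()
    {
        // Hypothetical tile parameters, for illustration only.
        constexpr std::size_t D0M1 = 4, NPerBlock = 128, D0M2 = 4;
        constexpr std::array<std::size_t, 6> lds_tile{1, 1, 1, D0M1, NPerBlock, D0M2};
        std::printf("LDS elements per D0 tile: %zu\n", packed_element_space_size(lds_tile));
        return 0;
    }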
@@ -1351,7 +1351,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             typename TypeTransform<D0DataType>::Type,         // SrcData
             typename TypeTransform<D0DataType>::Type,         // DstData
             D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // SrcDesc
-            decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3),  // DstDesc
+            decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
             Sequence<0, 1, 2, 3, 5, 4>,                       // SrcDimAccessOrder
             Sequence<0, 1, 2, 4, 3, 5>,                       // DstDimAccessOrder
             4,                                                // SrcVectorDim
@@ -1367,13 +1367,56 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         using D0ThreadCopy =
             ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
                                              typename TypeTransform<D0DataType>::Type, // DstData
-                                             decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+                                             decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
                                              decltype(d0_thread_desc_),  // DstDesc
                                              Sequence<1, 1, 4, 1, 4>,    // SliceLengths
                                              Sequence<0, 1, 2, 3, 4>,    // DimAccessOrder
                                              4,                          // SrcVectorDim
                                              2,                          // SrcScalarPerVector
                                              2>;

+        using D0ThreadCopyVgprToLds =
+            ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
+                                               typename TypeTransform<D0DataType>::Type,
+                                               decltype(d0_thread_desc_),
+                                               decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
+                                               tensor_operation::element_wise::Scale, // CElementwiseOperation
+                                               Sequence<1, 1, 4, 1, 4>,        // SliceLengths
+                                               Sequence<0, 1, 2, 3, 4>,        // AccessOrder
+                                               4,                              // VectorDim
+                                               4,                              // ScalarPerVector
+                                               InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
+                                               1,                              // DstScalarStrideInVector
+                                               true>;
+
+        using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            tensor_operation::element_wise::PassThrough,
+            tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set,
+            Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+            Sequence<1, 1, 1, BlockSize / NThreadClusterLengths, NThreadClusterLengths, 1>, // ThreadClusterLengths
+            Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
+            typename TypeTransform<D0DataType>::Type,         // SrcData
+            typename TypeTransform<D0DataType>::Type,         // DstData
+            decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+            D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // DstDesc
+            Sequence<0, 1, 2, 4, 3, 5>,        // SrcDimAccessOrder
+            Sequence<0, 1, 2, 3, 5, 4>,        // DstDimAccessOrder
+            5,                                 // SrcVectorDim
+            4,                                 // DstVectorDim
+            4,                                 // SrcScalarPerVector
+            D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
+            1,
+            1,
+            true,
+            true, // DstResetCoord
+            1>;
     };

     struct SharedMemTrait
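The two new transfer types complete a round trip for the bias gradient: D0ThreadCopyVgprToLds stores each thread's accumulator slice (FloatGemmAcc) into the LDS tile, converting to D0DataType and applying an element-wise Scale on the way, and D0BlockwiseCopyLdsToGlobal then writes the assembled tile to global memory. A minimal sketch of a copy fused with an element-wise scale, under the assumption that Scale simply multiplies by a factor captured at construction:

    #include <cstdio>

    // Element-wise op in the spirit of tensor_operation::element_wise::Scale:
    // y = scale * x, with the factor fixed at construction time.
    struct Scale
    {
        float scale_;
        void operator()(float& y, float x) const { y = scale_ * x; }
    };

    // Copy that applies the element-wise op in flight, converting from the
    // accumulator type (float here) to the destination type.
    template <typename DstType, typename ElementOp>
    void copy_with_op(const float* src, DstType* dst, int n, const ElementOp& op)
    {
        for(int i = 0; i < n; ++i)
        {
            float y;
            op(y, src[i]);
            dst[i] = static_cast<DstType>(y); // type conversion happens here
        }
    }

    int main()
    {
        const float sgrad[4] = {0.1f, 0.2f, 0.3f, 0.4f}; // accumulator values
        float d0grad[4];
        copy_with_op(sgrad, d0grad, 4, Scale{10.0f});     // e.g. rp_dropout = 10
        std::printf("%g %g %g %g\n", d0grad[0], d0grad[1], d0grad[2], d0grad[3]);
        return 0;
    }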
@@ -1417,11 +1460,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             sizeof(GemmDataType) / sizeof(FloatGemmAcc);

         static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-            D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
         static constexpr auto d0_block_space_offset =
             (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
              q_block_space_size_aligned.value) *
-            sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+            sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
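The D0 tile is stacked in LDS behind the K, dY (ygrad), and Q tiles, so its offset is the aligned sum of those three sizes, with a sizeof scaling to account for D0DataType possibly differing from GemmDataType. A host-side sketch of the integer_least_multiple-style arithmetic, with made-up tile sizes for illustration:

    #include <cstdio>

    // Round x up to the nearest multiple of `align`, like math::integer_least_multiple.
    constexpr long integer_least_multiple(long x, long align)
    {
        return ((x + align - 1) / align) * align;
    }

    int main()
    {
        // Hypothetical per-tile element counts and alignment, for illustration only.
        const long max_lds_align = 64;
        const long k_tile     = integer_least_multiple(1000, max_lds_align); // 1024
        const long ygrad_tile = integer_least_multiple(900, max_lds_align);  // 960
        const long q_tile     = integer_least_multiple(1000, max_lds_align); // 1024
        const long sizeof_gemm_t = 2; // e.g. fp16 compute tiles
        const long sizeof_d0_t   = 2; // e.g. fp16 D0 tile

        // D0 tile starts after the K, dY and Q tiles, measured in D0 elements.
        const long d0_offset = (k_tile + ygrad_tile + q_tile) * sizeof_gemm_t / sizeof_d0_t;
        std::printf("d0_block_space_offset = %ld elements\n", d0_offset);
        return 0;
    }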
@@ -1441,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             sizeof(FloatGemmAcc);
         const index_t d0_bytes_end =
             (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-            D0Loader::template TypeTransform<D0DataType>::Size0;
+            D0Operator::template TypeTransform<D0DataType>::Size0;
         const index_t c_block_bytes_end =
             SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
@@ -1465,6 +1509,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         const InputDataType* __restrict__ p_ygrad_grid,
         OutputDataType* __restrict__ p_qgrad_grid,
         OutputDataType* __restrict__ p_kgrad_grid,
+        D0DataType* __restrict__ p_d0grad_grid,
         OutputDataType* __restrict__ p_vgrad_grid,
         void* __restrict__ p_shared,
         const AElementwiseOperation& a_element_op,
@@ -2010,17 +2055,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         // gemm0 M loop
         index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
         // D0
-        auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+        auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopy(
             d0_grid_desc_m0_n0_m1_m2_n1_m3,
             make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
             tensor_operation::element_wise::PassThrough{},
-            D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
             make_multi_index(0, 0, 0, 0, 0, 0),
             tensor_operation::element_wise::PassThrough{});
-        auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
+        auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy(
             make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+        auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+            D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+            tensor_operation::element_wise::Scale{rp_dropout});
+        auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
+            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+            make_multi_index(0, 0, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{},
+            d0_grid_desc_m0_n0_m1_m2_n1_m3,
+            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{});

         if constexpr(Deterministic)
         {
             block_sync_lds();
@@ -2202,10 +2260,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                 static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
             auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                D0Operator::d0_thread_desc_.GetElementSpaceSize());

             static_for<0, D0M0, 1>{}([&](auto mr) {
                 // load data to lds
@@ -2216,13 +2274,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                     d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                 d0_block_copy_global_to_lds.RunWrite(
-                    D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                    D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                 block_sync_lds();
                 // read data from lds
-                d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                                                make_tuple(I0, I0, I0, I0, I0),
                                                d0_block_buf,
-                                               D0Loader::d0_thread_desc_,
+                                               D0Operator::d0_thread_desc_,
                                                make_tuple(I0, I0, I0, I0, I0),
                                                d0_thread_buf);
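The read loop above walks the M0 dimension of the grid descriptor one window at a time; the matching write-out added at the end of this file steps its destination window the same way and, once the static_for finishes, rewinds by (-1, 0, -D0M0, 0, 0, 0) so the next outer gemm0 iteration starts from a clean origin. A scalar sketch of that move-then-rewind bookkeeping, with the six-dimensional index simplified to two dimensions:

    #include <cstdio>

    // Minimal slice-window cursor over (outer, m0): mirrors MoveDstSliceWindow
    // bookkeeping, where steps are relative and forward motion must be undone.
    struct WindowCursor
    {
        int outer = 0;
        int m0    = 0;
        void move(int d_outer, int d_m0)
        {
            outer += d_outer;
            m0 += d_m0;
        }
    };

    int main()
    {
        const int D0M0 = 4; // number of m0 sub-tiles per block tile (illustrative)
        WindowCursor cur;
        cur.move(+1, 0); // position at the current outer block row

        for(int mr = 0; mr < D0M0; ++mr)
        {
            // ... copy one (outer, m0) window here ...
            cur.move(0, +1); // next m0 window, like make_multi_index(0,0,1,0,0,0)
        }
        // Rewind: undo the D0M0 forward steps and retreat one outer row,
        // like make_multi_index(-1, 0, -D0M0.value, 0, 0, 0).
        cur.move(-1, -D0M0);
        std::printf("cursor after rewind: outer=%d m0=%d\n", cur.outer, cur.m0); // 0 0
        return 0;
    }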
@@ -2328,6 +2386,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                         : y_dot_ygrad_thread_buf[Number<m>{}]);
             });

+            // output bias grad
+            if constexpr(!is_same<D0DataType, void>::value)
+            {
+                if(p_d0grad_grid != nullptr)
+                {
+                    auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                    auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                        D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                    static_for<0, D0M0, 1>{}([&](auto mr) {
+                        d0grad_thread_copy_vgpr_to_lds.Run(
+                            D0Operator::d0_thread_desc_,
+                            make_tuple(mr, I0, I0, I0, I0),
+                            sgrad_thread_buf,
+                            D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+                            d0grad_block_buf);
+                        block_sync_lds();
+                        // write data from lds to global
+                        d0_block_copy_lds_to_global.Run(
+                            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+                            d0grad_block_buf,
+                            d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                            d0grad_grid_buf,
+                            I0);
+                        d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                    });
+                    d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                }
+            }
+
             // gemm dV
             // dV = P_drop^T * dY
             {
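The new block writes the bias gradient out of sgrad_thread_buf: since the bias D0 enters the forward pass additively (S = Q*K^T + D0 before softmax), its gradient is simply dS, which this kernel already holds, and the Scale{rp_dropout} op applied during the VGPR-to-LDS store supplies the dropout rescale. A naive host reference for the additive-bias gradient under those assumptions (a sanity-check sketch with invented values, not the kernel's algorithm):

    #include <cstdio>

    // Reference: with S = Q*K^T + D0 (bias added before softmax), dL/dD0 equals
    // dL/dS elementwise. With inverted dropout folded in, the stored gradient
    // is rp_dropout * dS for kept elements.
    int main()
    {
        const int M = 2, N = 2;
        const float rp_dropout = 10.0f; // 1 / (1 - 0.9), matching the example
        const float dS[M * N] = {0.1f, -0.2f, 0.05f, 0.0f}; // pretend softmax grad
        float dD0[M * N];

        for(int i = 0; i < M * N; ++i)
            dD0[i] = rp_dropout * dS[i]; // additive bias: pass dS straight through

        for(int i = 0; i < M * N; ++i)
            std::printf("dD0[%d] = %g\n", i, dD0[i]);
        return 0;
    }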