Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c88d1173
Commit
c88d1173
authored
Sep 22, 2023
by
letaoqin
Browse files
change d0 operator variables name
parent
12c0f86a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
259 additions
and
243 deletions
+259
-243
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+65
-61
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+64
-60
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+65
-61
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+64
-60
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
c88d1173
...
@@ -25,7 +25,7 @@ Kernel outputs:
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
c88d1173
...
@@ -91,7 +91,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -91,7 +91,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_assert
(
KPerBlock
==
Gemm1NPerBlock
);
static_assert
(
KPerBlock
==
Gemm1NPerBlock
);
static_assert
(
MPerBlock
%
Gemm1KPerBlock
==
0
);
static_assert
(
MPerBlock
%
Gemm1KPerBlock
==
0
);
static_assert
(
NPerBlock
%
Gemm2KPerBlock
==
0
);
static_assert
(
NPerBlock
%
Gemm2KPerBlock
==
0
);
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -1257,14 +1257,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1257,14 +1257,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
return
d0_n0_n1_m0_m1_m2
;
}
}
static
constexpr
auto
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
=
static
constexpr
auto
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
=
static
constexpr
auto
d0_block_
src
_desc_n0_n1_m0_m1_m2
=
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
static
constexpr
auto
&
d0grad_block_dst_desc_n0_n1_m0_m1_m2
=
d0_block_src_desc_n0_n1_m0_m1_m2
;
static
constexpr
auto
&
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
=
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
;
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1276,18 +1281,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1276,18 +1281,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
decltype
(
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
5
,
// DstVectorDim
5
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1295,21 +1300,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1295,21 +1300,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1
>
;
1
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_
src
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
using
D0ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
D0
Grad
ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
typename
TypeTransform
<
D0DataType
>::
Type
,
typename
TypeTransform
<
D0DataType
>::
Type
,
decltype
(
d0_thread_desc_
),
decltype
(
d0_thread_desc_
),
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
decltype
(
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
),
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
...
@@ -1319,7 +1324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1319,7 +1324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
;
true
>
;
using
D0BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0
Grad
BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1330,18 +1335,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1330,18 +1335,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
decltype
(
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
5
,
// SrcVectorDim
5
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1381,8 +1386,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1381,8 +1386,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
q_block_space_size_aligned
.
value
;
q_block_space_size_aligned
.
value
;
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
D0Operator
::
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
static
constexpr
auto
d0_block_space_offset
=
(
k_block_space_size_aligned
.
value
+
ygrad_block_space_size_aligned
.
value
+
(
k_block_space_size_aligned
.
value
+
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
)
*
q_block_space_size_aligned
.
value
)
*
...
@@ -1898,23 +1902,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1898,23 +1902,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0ThreadwiseCopyVgprToLds
(
auto
&
d0grad_grid_desc_m0_n0_m1_m2_n1_m3
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0GradThreadwiseCopyVgprToLds
(
D0Operator
::
d0grad_block_dst_desc_n0_n1_m0_m1_m2
,
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
auto
d0_block_copy_lds_to_global
=
typename
D0Operator
::
D0BlockwiseCopyLdsToGlobal
(
auto
d0
grad
_block_copy_lds_to_global
=
typename
D0Operator
::
D0
Grad
BlockwiseCopyLdsToGlobal
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -2062,7 +2067,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -2062,7 +2067,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
...
@@ -2076,16 +2081,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -2076,16 +2081,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
block_sync_lds
();
// read data from lds
// read data from lds
d0_thread_copy_lds_to_vgpr
.
Run
(
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
d0_thread_buf
);
// bias add
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
...
@@ -2197,36 +2201,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -2197,36 +2201,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
if
(
p_d0grad_grid
!=
nullptr
)
if
(
p_d0grad_grid
!=
nullptr
)
{
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
p_d0grad_grid
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
d0grad_thread_copy_vgpr_to_lds
.
Run
(
d0grad_thread_copy_vgpr_to_lds
.
Run
(
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
sgrad_thread_buf
,
D0Operator
::
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
,
d0grad_block_buf
);
d0grad_block_buf
);
block_sync_lds
();
block_sync_lds
();
// write data from lds to global
// write data from lds to global
d0_block_copy_lds_to_global
.
Run
(
d0
grad
_block_copy_lds_to_global
.
Run
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
d0grad_block_buf
,
d0grad_block_buf
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
d0grad_grid_buf
,
d0grad_grid_buf
,
I0
);
I0
);
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
});
});
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
c88d1173
...
@@ -1336,14 +1336,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1336,14 +1336,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
return
d0_n0_n1_m0_m1_m2
;
}
}
static
constexpr
auto
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
=
static
constexpr
auto
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
=
static
constexpr
auto
d0_block_
src
_desc_n0_n1_m0_m1_m2
=
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
static
constexpr
auto
&
d0grad_block_dst_desc_n0_n1_m0_m1_m2
=
d0_block_src_desc_n0_n1_m0_m1_m2
;
static
constexpr
auto
&
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
=
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
;
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1355,18 +1360,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1355,18 +1360,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
decltype
(
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
5
,
// DstVectorDim
5
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1374,21 +1379,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1374,21 +1379,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1
>
;
1
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_
src
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
using
D0ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
D0
Grad
ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
typename
TypeTransform
<
D0DataType
>::
Type
,
typename
TypeTransform
<
D0DataType
>::
Type
,
decltype
(
d0_thread_desc_
),
decltype
(
d0_thread_desc_
),
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
decltype
(
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
),
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
...
@@ -1398,7 +1403,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1398,7 +1403,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
;
true
>
;
using
D0BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0
Grad
BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1409,18 +1414,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1409,18 +1414,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
decltype
(
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
5
,
// SrcVectorDim
5
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1460,8 +1465,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1460,8 +1465,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
D0Operator
::
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size
;
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size
;
...
@@ -2019,23 +2023,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -2019,23 +2023,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0ThreadwiseCopyVgprToLds
(
auto
&
d0grad_grid_desc_m0_n0_m1_m2_n1_m3
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0GradThreadwiseCopyVgprToLds
(
D0Operator
::
d0grad_block_dst_desc_n0_n1_m0_m1_m2
,
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
auto
d0_block_copy_lds_to_global
=
typename
D0Operator
::
D0BlockwiseCopyLdsToGlobal
(
auto
d0
grad
_block_copy_lds_to_global
=
typename
D0Operator
::
D0
Grad
BlockwiseCopyLdsToGlobal
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -2213,7 +2218,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -2213,7 +2218,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
...
@@ -2227,16 +2232,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -2227,16 +2232,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
block_sync_lds
();
// read data from lds
// read data from lds
d0_thread_copy_lds_to_vgpr
.
Run
(
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
d0_thread_buf
);
// bias add
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
...
@@ -2464,36 +2468,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -2464,36 +2468,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
if
(
p_d0grad_grid
!=
nullptr
)
if
(
p_d0grad_grid
!=
nullptr
)
{
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
p_d0grad_grid
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
d0grad_thread_copy_vgpr_to_lds
.
Run
(
d0grad_thread_copy_vgpr_to_lds
.
Run
(
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
sgrad_thread_buf
,
D0Operator
::
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_
src
_desc_n0_n1_m0_m1_m2
,
d0grad_block_buf
);
d0grad_block_buf
);
block_sync_lds
();
block_sync_lds
();
// write data from lds to global
// write data from lds to global
d0_block_copy_lds_to_global
.
Run
(
d0
grad
_block_copy_lds_to_global
.
Run
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
d0grad_block_buf
,
d0grad_block_buf
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
d0grad_grid_buf
,
d0grad_grid_buf
,
I0
);
I0
);
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
});
});
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
c88d1173
...
@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert
(
KPerBlock
==
Gemm1NPerBlock
);
static_assert
(
KPerBlock
==
Gemm1NPerBlock
);
static_assert
(
MPerBlock
%
Gemm1KPerBlock
==
0
);
static_assert
(
MPerBlock
%
Gemm1KPerBlock
==
0
);
static_assert
(
NPerBlock
%
Gemm2KPerBlock
==
0
);
static_assert
(
NPerBlock
%
Gemm2KPerBlock
==
0
);
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -1325,14 +1325,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1325,14 +1325,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
return
d0_n0_n1_m0_m1_m2
;
}
}
static
constexpr
auto
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
=
static
constexpr
auto
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
=
static
constexpr
auto
d0_block_
src
_desc_n0_n1_m0_m1_m2
=
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
static
constexpr
auto
&
d0grad_block_dst_desc_n0_n1_m0_m1_m2
=
d0_block_src_desc_n0_n1_m0_m1_m2
;
static
constexpr
auto
&
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
=
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
;
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1344,18 +1349,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1344,18 +1349,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
decltype
(
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
5
,
// DstVectorDim
5
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1363,21 +1368,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1363,21 +1368,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1
>
;
1
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_
src
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
using
D0ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
D0
Grad
ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
typename
TypeTransform
<
D0DataType
>::
Type
,
typename
TypeTransform
<
D0DataType
>::
Type
,
decltype
(
d0_thread_desc_
),
decltype
(
d0_thread_desc_
),
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
decltype
(
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
),
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
...
@@ -1387,7 +1392,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1387,7 +1392,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
;
true
>
;
using
D0BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0
Grad
BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1398,18 +1403,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1398,18 +1403,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
decltype
(
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
5
,
// SrcVectorDim
5
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1458,8 +1463,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1458,8 +1463,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
D0Operator
::
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
static
constexpr
auto
d0_block_space_offset
=
(
k_block_space_size_aligned
.
value
+
ygrad_block_space_size_aligned
.
value
+
(
k_block_space_size_aligned
.
value
+
ygrad_block_space_size_aligned
.
value
+
q_block_space_size_aligned
.
value
)
*
q_block_space_size_aligned
.
value
)
*
...
@@ -2060,23 +2064,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -2060,23 +2064,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0ThreadwiseCopyVgprToLds
(
auto
&
d0grad_grid_desc_m0_n0_m1_m2_n1_m3
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0GradThreadwiseCopyVgprToLds
(
D0Operator
::
d0grad_block_dst_desc_n0_n1_m0_m1_m2
,
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
auto
d0_block_copy_lds_to_global
=
typename
D0Operator
::
D0BlockwiseCopyLdsToGlobal
(
auto
d0
grad
_block_copy_lds_to_global
=
typename
D0Operator
::
D0
Grad
BlockwiseCopyLdsToGlobal
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -2263,7 +2268,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -2263,7 +2268,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
...
@@ -2277,16 +2282,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -2277,16 +2282,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
block_sync_lds
();
// read data form lds
// read data form lds
d0_thread_copy_lds_to_vgpr
.
Run
(
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
d0_thread_buf
);
// bias add
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
...
@@ -2398,36 +2402,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -2398,36 +2402,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
if
(
p_d0grad_grid
!=
nullptr
)
if
(
p_d0grad_grid
!=
nullptr
)
{
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
p_d0grad_grid
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
d0grad_thread_copy_vgpr_to_lds
.
Run
(
d0grad_thread_copy_vgpr_to_lds
.
Run
(
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
sgrad_thread_buf
,
D0Operator
::
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
,
d0grad_block_buf
);
d0grad_block_buf
);
block_sync_lds
();
block_sync_lds
();
// write data from lds to global
// write data from lds to global
d0_block_copy_lds_to_global
.
Run
(
d0
grad
_block_copy_lds_to_global
.
Run
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
d0grad_block_buf
,
d0grad_block_buf
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
d0grad_grid_buf
,
d0grad_grid_buf
,
I0
);
I0
);
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
});
});
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
c88d1173
...
@@ -1381,14 +1381,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1381,14 +1381,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
return
d0_n0_n1_m0_m1_m2
;
}
}
static
constexpr
auto
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
=
static
constexpr
auto
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
=
static
constexpr
auto
d0_block_
src
_desc_n0_n1_m0_m1_m2
=
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
static
constexpr
auto
&
d0grad_block_dst_desc_n0_n1_m0_m1_m2
=
d0_block_src_desc_n0_n1_m0_m1_m2
;
static
constexpr
auto
&
d0grad_block_src_desc_m0_n0_m1_m2_n1_m3
=
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
;
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1400,18 +1405,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1400,18 +1405,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
decltype
(
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
5
,
// DstVectorDim
5
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1419,21 +1424,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1419,21 +1424,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1
>
;
1
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_block_
src
_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
using
D0ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
D0
Grad
ThreadwiseCopyVgprToLds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
typename
TypeTransform
<
D0DataType
>::
Type
,
typename
TypeTransform
<
D0DataType
>::
Type
,
decltype
(
d0_thread_desc_
),
decltype
(
d0_thread_desc_
),
decltype
(
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
),
decltype
(
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
),
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
tensor_operation
::
element_wise
::
Scale
,
// CElementwiseOperation
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// AccessOrder
...
@@ -1443,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1443,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
;
true
>
;
using
D0BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0
Grad
BlockwiseCopyLdsToGlobal
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1454,18 +1459,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1454,18 +1459,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
decltype
(
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
),
// SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// DstDimAccessOrder
5
,
// SrcVectorDim
5
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
D0BlockTransferSrcScalarPerVector
,
// DstScalarPerVector
1
,
1
,
1
,
1
,
true
,
true
,
...
@@ -1512,8 +1517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1512,8 +1517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
D0Operator
::
d0_block_dst_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size
;
D0Operator
::
template
TypeTransform
<
D0DataType
>
::
Size
;
...
@@ -2132,23 +2136,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2132,23 +2136,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0ThreadwiseCopyVgprToLds
(
auto
&
d0grad_grid_desc_m0_n0_m1_m2_n1_m3
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
auto
d0grad_thread_copy_vgpr_to_lds
=
typename
D0Operator
::
D0GradThreadwiseCopyVgprToLds
(
D0Operator
::
d0grad_block_dst_desc_n0_n1_m0_m1_m2
,
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
auto
d0_block_copy_lds_to_global
=
typename
D0Operator
::
D0BlockwiseCopyLdsToGlobal
(
auto
d0
grad
_block_copy_lds_to_global
=
typename
D0Operator
::
D0
Grad
BlockwiseCopyLdsToGlobal
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -2365,7 +2370,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2365,7 +2370,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
...
@@ -2379,16 +2384,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2379,16 +2384,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
D0Operator
::
d0_block_
dst
_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
block_sync_lds
();
// read data form lds
// read data form lds
d0_thread_copy_lds_to_vgpr
.
Run
(
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
d0_thread_buf
);
// bias add
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
...
@@ -2616,36 +2620,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2616,36 +2620,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if
(
p_d0grad_grid
!=
nullptr
)
if
(
p_d0grad_grid
!=
nullptr
)
{
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
p_d0grad_grid
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
d0grad_thread_copy_vgpr_to_lds
.
Run
(
d0grad_thread_copy_vgpr_to_lds
.
Run
(
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
sgrad_thread_buf
,
D0Operator
::
d0_block_
vgpr
_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0
grad
_block_
dst
_desc_n0_n1_m0_m1_m2
,
d0grad_block_buf
);
d0grad_block_buf
);
block_sync_lds
();
block_sync_lds
();
// write data from lds to global
// write data from lds to global
d0_block_copy_lds_to_global
.
Run
(
d0
grad
_block_copy_lds_to_global
.
Run
(
D0Operator
::
d0_block_
global
_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0
grad
_block_
src
_desc_m0_n0_m1_m2_n1_m3
,
d0grad_block_buf
,
d0grad_block_buf
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
d0grad_grid_buf
,
d0grad_grid_buf
,
I0
);
I0
);
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
});
});
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0
grad
_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0
grad
_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment