Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
82347535
Commit
82347535
authored
Nov 27, 2023
by
letaoqin
Browse files
fix load y and y_grad vector size
parent
2f93e26f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
66 additions
and
55 deletions
+66
-55
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+10
-8
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+10
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+10
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+10
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+7
-6
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+4
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
...gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
+15
-14
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
82347535
...
...
@@ -852,14 +852,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
Deterministic
>
;
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
>
;
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
82347535
...
...
@@ -869,14 +869,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
Deterministic
>
;
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
>
;
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
82347535
...
...
@@ -821,14 +821,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
>
;
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
82347535
...
...
@@ -890,14 +890,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
>
;
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
DDataType
,
DYGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
DMPerBlock
,
DKPerBlock
,
Gemm1NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
82347535
...
...
@@ -1250,6 +1250,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
static_assert
(
SrcScalarPerVector
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
...
...
@@ -2007,9 +2008,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
3
,
// SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
...
...
@@ -2021,9 +2022,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype
(
ygrad_thread_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
.
GetLengths
()),
Sequence
<
0
,
1
>
,
1
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
1
,
// SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
,
ygrad_thread_data_on_block_idx
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
82347535
...
...
@@ -1240,6 +1240,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
static_assert
(
SrcScalarPerVector
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
...
...
@@ -2102,9 +2103,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
3
,
// SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
View file @
82347535
...
...
@@ -27,7 +27,8 @@ template <typename InputDataType,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPadded
>
index_t
NPadded
,
index_t
YSrcScalarPerVector
>
struct
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -194,19 +195,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
y_thread_data_on_block_idx
;
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
InputDataType
,
FloatD
,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
decltype
(
y_thread_desc_m0_m1_n0_n1
),
decltype
(
y_thread_desc_m0_m1_n0_n1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_N
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
y_thread_data_on_grid_idx
);
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
InputDataType
,
FloatD
,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
decltype
(
y_thread_desc_m0_m1_n0_n1
),
decltype
(
y_thread_desc_m0_m1_n0_n1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YSrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
y_thread_data_on_grid_idx
);
auto
y_thread_buf
=
typename
YDotYGrad_M_N
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_N
::
SrcBufType
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment