Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
227076ba
Commit
227076ba
authored
Jul 06, 2023
by
ltqin
Browse files
change O to N
parent
796b544e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
18 deletions
+18
-18
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
...dwise_batched_multihead_attention_bacckward_ydotygrad.hpp
+18
-18
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
View file @
227076ba
...
@@ -115,7 +115,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -115,7 +115,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_
O
_
struct
YDotYGrad_M_
N
_
{
{
static_assert
(
BlockSize_
==
BlockSliceLength_M_
);
static_assert
(
BlockSize_
==
BlockSliceLength_M_
);
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
1
>
{};
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
1
>
{};
...
@@ -134,7 +134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -134,7 +134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatD
,
ThreadSliceLength_M
,
true
>
;
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatD
,
ThreadSliceLength_M
,
true
>
;
};
};
using
YDotYGrad_M_
O
=
YDotYGrad_M_
O
_
<
BlockSize
,
MPerBlock
,
NPerBlock
>
;
using
YDotYGrad_M_
N
=
YDotYGrad_M_
N
_
<
BlockSize
,
MPerBlock
,
NPerBlock
>
;
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_y_grid
,
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_y_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
...
@@ -169,20 +169,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -169,20 +169,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
constexpr
auto
d_thread_desc_mblock_mrepeat_mwave_mperxdl
=
constexpr
auto
d_thread_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
));
constexpr
auto
y_thread_desc_m0_m1_
o
0_
o
1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
y_thread_desc_m0_m1_
n
0_
n
1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_
O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_
O
::
ThreadSliceLength_O
));
I1
,
YDotYGrad_M_
N
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_
N
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_
O
::
ThreadClusterLength_M
,
YDotYGrad_M_
N
::
ThreadClusterLength_M
,
I1
,
I1
,
YDotYGrad_M_
O
::
ThreadClusterLength_O
>
{},
YDotYGrad_M_
N
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
Sequence
<
0
,
1
,
2
,
3
>
{});
const
auto
y_thread_cluster_idx
=
const
auto
y_thread_cluster_idx
=
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
y_thread_data_on_block_idx
=
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_
o
0_
o
1
.
GetLengths
();
y_thread_cluster_idx
*
y_thread_desc_m0_m1_
n
0_
n
1
.
GetLengths
();
const
auto
y_thread_data_on_grid_idx
=
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
make_multi_index
(
...
@@ -194,19 +194,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -194,19 +194,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
InputDataType
,
InputDataType
,
FloatD
,
FloatD
,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
decltype
(
y_thread_desc_m0_m1_
o
0_
o
1
),
decltype
(
y_thread_desc_m0_m1_
n
0_
n
1
),
decltype
(
y_thread_desc_m0_m1_
o
0_
o
1
.
GetLengths
()),
decltype
(
y_thread_desc_m0_m1_
n
0_
n
1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
3
,
// SrcVectorDim
YDotYGrad_M_
O
::
SrcScalarPerVector
,
// SrcScalarPerVector
YDotYGrad_M_
N
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
true
/* ResetCoordAfterRun */
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
y_thread_data_on_grid_idx
);
y_thread_data_on_grid_idx
);
auto
y_thread_buf
=
typename
YDotYGrad_M_
O
::
SrcBufType
{};
auto
y_thread_buf
=
typename
YDotYGrad_M_
N
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_
O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_
N
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_
O
::
DstBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_
N
::
DstBufType
{};
// clear accum buffers
// clear accum buffers
y_dot_ygrad_thread_accum_buf
.
Clear
();
y_dot_ygrad_thread_accum_buf
.
Clear
();
...
@@ -216,19 +216,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -216,19 +216,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
{
yygrad_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
yygrad_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
y_grid_buf
,
y_grid_buf
,
y_thread_desc_m0_m1_
o
0_
o
1
,
y_thread_desc_m0_m1_
n
0_
n
1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
y_thread_buf
);
y_thread_buf
);
yygrad_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
yygrad_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_nblock_nperblock
,
ygrad_grid_buf
,
ygrad_grid_buf
,
y_thread_desc_m0_m1_
o
0_
o
1
,
y_thread_desc_m0_m1_
n
0_
n
1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
ygrad_thread_buf
);
ygrad_thread_buf
);
static_for
<
0
,
YDotYGrad_M_
O
::
ThreadSliceLength_M
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
YDotYGrad_M_
N
::
ThreadSliceLength_M
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
YDotYGrad_M_
O
::
ThreadSliceLength_O
,
1
>
{}([
&
](
auto
iO
)
{
static_for
<
0
,
YDotYGrad_M_
N
::
ThreadSliceLength_O
,
1
>
{}([
&
](
auto
iO
)
{
constexpr
auto
offset
=
constexpr
auto
offset
=
y_thread_desc_m0_m1_
o
0_
o
1
.
CalculateOffset
(
make_multi_index
(
I0
,
iM
,
I0
,
iO
));
y_thread_desc_m0_m1_
n
0_
n
1
.
CalculateOffset
(
make_multi_index
(
I0
,
iM
,
I0
,
iO
));
y_dot_ygrad_thread_accum_buf
(
iM
)
+=
y_dot_ygrad_thread_accum_buf
(
iM
)
+=
y_thread_buf
[
Number
<
offset
>
{}]
*
ygrad_thread_buf
[
Number
<
offset
>
{}];
y_thread_buf
[
Number
<
offset
>
{}]
*
ygrad_thread_buf
[
Number
<
offset
>
{}];
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment