Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
21ef37b4
Unverified
Commit
21ef37b4
authored
Sep 11, 2023
by
Dan Yao
Committed by
GitHub
Sep 11, 2023
Browse files
Merge pull request #889 from ROCmSoftwarePlatform/mha-train-develop-bwdopt-bias
Mha train develop bwdopt bias
parents
1f04cd2b
db579ac9
Changes
22
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
729 additions
and
393 deletions
+729
-393
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+722
-390
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
...gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
+7
-3
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
21ef37b4
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
View file @
21ef37b4
...
@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
y_grid_desc_mblock_mperblock_nblock_nperblock
,
y_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
float
p_drop
)
{
{
const
FloatD
p_dropout
=
type_convert
<
FloatD
>
(
1.0
f
-
p_drop
);
const
tensor_operation
::
element_wise
::
Scale
scale_p_dropout
(
p_dropout
);
const
auto
y_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
y_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_grid
,
y_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_y_grid
,
y_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
FloatD
,
FloatD
,
decltype
(
d_thread_desc_mblock_m1
),
decltype
(
d_thread_desc_mblock_m1
),
decltype
(
d_grid_desc_mblock_mperblock
),
decltype
(
d_grid_desc_mblock_mperblock
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Scale
,
Sequence
<
1
,
1
>
,
Sequence
<
1
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
1
,
...
@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
d_grid_desc_mblock_mperblock
,
d_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx_m
,
// mblock
make_multi_index
(
block_work_idx_m
,
// mblock
get_thread_local_1d_id
()),
// mperblock
get_thread_local_1d_id
()),
// mperblock
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
};
scale_p_dropout
};
// copy from VGPR to Global
// copy from VGPR to Global
d_thread_copy_vgpr_to_global
.
Run
(
d_thread_desc_mblock_m1
,
d_thread_copy_vgpr_to_global
.
Run
(
d_thread_desc_mblock_m1
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment