Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a11680cc
Unverified
Commit
a11680cc
authored
Jul 15, 2022
by
Anthony Chang
Committed by
GitHub
Jul 14, 2022
Browse files
fix standalone softmax race condition around blockwise reduction (#323)
parent
7f216620
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
3 deletions
+6
-3
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
+6
-3
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
View file @
a11680cc
...
...
@@ -250,8 +250,10 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
[
&
](
auto
I
)
{
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I
));
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I
));
block_sync_lds
();
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_bwd_step
);
...
...
@@ -303,9 +305,10 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
block_sync_lds
();
// wait for reading being complete before writing to LDS
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
//
block_sync_lds();
block_sync_lds
();
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_fwd_step
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment