Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
ea38d3d2
Commit
ea38d3d2
authored
Jun 25, 2022
by
Tri Dao
Browse files
Fix race condition in backward pass (smem_dq)
parent
eeca63a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
0 deletions
+3
-0
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+3
-0
No files found.
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
View file @
ea38d3d2
...
...
@@ -475,6 +475,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
// Swizzle the elements and do the final reduction.
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
// might happen after the smem_dq writes in this iteration.
__syncthreads
();
smem_dq
.
store
(
acc_dq
,
0
);
typename
Smem_tile_dot
::
Fragment
frag_dot
[
2
][
Mma_tile_dkv
::
MMAS_N
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment