Commit ea38d3d2 authored by Tri Dao's avatar Tri Dao
Browse files

Fix race condition in backward pass (smem_dq)

parent eeca63a7
......@@ -475,6 +475,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
static_assert(Gmem_tile_dq::LOOPS == 1);
// Swizzle the elements and do the final reduction.
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
// might happen after the smem_dq writes in this iteration.
__syncthreads();
smem_dq.store(acc_dq, 0);
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment