Commit 853f9c3d (unverified), authored by Zhengju Tang and committed by GitHub
Browse files

[BugFix] Add memory order and testing script for split version GQA bwd kernel (#1100)

* [BugFix] Add memory order for split version kernel; Remove torch manual seed

* [Lint] Manual
parent 4c9da81a
......@@ -7,8 +7,6 @@ import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
torch.manual_seed(1)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
......@@ -525,7 +523,10 @@ def flashattn_bwd_split(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
T.atomic_add(
dQ[q_start_idx + k_base * block_N + i, bx, j],
dq[i, j],
memory_order="release")
T.copy(dv, dv_shared)
for i, d in T.Parallel(block_M, dim_v):
......@@ -739,9 +740,9 @@ def main(BATCH: int = 1,
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
......@@ -784,8 +785,8 @@ if __name__ == "__main__":
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
# Default: use split
use_atomic = False
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
......@@ -12,6 +12,12 @@ import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
    # Smoke test for the varlen GQA backward example: call the example
    # module's main() with no arguments, so it runs with whatever defaults
    # main() declares.
    # NOTE(review): the decorator name suggests this test only runs when CUDA
    # is available -- confirm against tilelang.testing.
    example_gqa_bwd_tma_reduce_varlen.main()
@tilelang.testing.requires_cuda
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment