Commit a8811d9b authored by Nathan Chen's avatar Nathan Chen Committed by LeiWang1999
Browse files

[Bugfix] Fixed mha_bwd shape inconsistency error (#604)

parent 3b52738d
...@@ -176,8 +176,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -176,8 +176,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dv = T.alloc_fragment([block_M, dim], accum_dtype) dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype) dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_N, dim], dtype) dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_N, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({ T.annotate_layout({
dQ: make_dq_layout(dQ), dQ: make_dq_layout(dQ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment