Commit b8ccd200 authored by Tri Dao
Browse files

[Triton] Fix variable name from qkv to kv (h/t FrankZijlstra)

parent 05481617
@@ -212,8 +212,8 @@ def _fwd_kernel(
     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
     tl.store(lse_ptrs, lse_i)
     # initialize pointers to output
-    offs_n = tl.arange(0, BLOCK_HEADDIM)
-    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
+    offs_d = tl.arange(0, BLOCK_HEADDIM)
+    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
@@ -789,7 +789,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         with torch.inference_mode():
             dq = torch.empty_like(q)
             dkv = torch.empty_like(kv)
-            _flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
+            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                                  dq, dkv[:, :, 0], dkv[:, :, 1],
                                  bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dkv, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment