Unverified commit f4628b43, authored by Phil Wang, committed by GitHub
Browse files

missing commas and backwards return arguments (#1032)

* missing commas

* another fix
parent 8f873cc6
......@@ -286,7 +286,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
-return dqkv, None, None, None, None, None, None, None
+return dqkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
......@@ -511,7 +511,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
-return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
+return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
......@@ -572,7 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
-ctx.softcap
+ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -580,7 +580,7 @@ class FlashAttnFunc(torch.autograd.Function):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
-return dq, dk, dv, None, None, None, None, None, None, None
+return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
......@@ -659,7 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
-ctx.softcap
+ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -667,7 +667,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
-return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment