Unverified commit f4628b43, authored by Phil Wang, committed by GitHub

missing commas and backwards return arguments (#1032)

* missing commas

* another fix
parent 8f873cc6
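
Background for the second part of the fix: torch.autograd.Function requires backward() to return one gradient (or None) per argument that forward() received, so when a new forward argument such as softcap is added, each backward return tuple needs a matching extra None. A minimal sketch of that contract, not flash-attn code (the names Scale, x, factor, and flag are made up for illustration):

import torch

class Scale(torch.autograd.Function):
    # Hypothetical example: forward takes three arguments after ctx,
    # so backward must return exactly three values.
    @staticmethod
    def forward(ctx, x, factor, flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, dout):
        # One entry per forward input: the gradient for x, and None for
        # the two non-tensor arguments. Returning too few or too many
        # values makes autograd raise an error at backward time.
        return dout * ctx.factor, None, None

# usage sketch:
# y = Scale.apply(torch.randn(3, requires_grad=True), 2.0, True)
# y.sum().backward()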
@@ -286,7 +286,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -511,7 +511,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
@@ -572,7 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
-            ctx.softcap
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -580,7 +580,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -659,7 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
-            ctx.softcap
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -667,7 +667,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(