Commit 81e01efd authored by Tri Dao

More typo fixes

parent 72e27c63
@@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     return_softmax,
     block_table,
@@ -102,6 +103,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -300,6 +302,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -318,6 +321,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -328,6 +332,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -355,12 +360,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -373,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -387,6 +394,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -395,6 +403,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -419,13 +428,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
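Note on the extra None added to each backward return above: torch.autograd.Function requires backward to return one gradient per argument of forward, and non-tensor arguments such as the new softcap float get None in their slot. The following is a minimal sketch of that convention only, using a hypothetical ScaledMatmul function and a hypothetical scale argument rather than the actual flash-attention kernels:

import torch

class ScaledMatmul(torch.autograd.Function):
    # Hypothetical example: forward takes tensors plus a non-tensor float,
    # mirroring how softcap is threaded through the FlashAttn*Func classes.
    @staticmethod
    def forward(ctx, x, weight, scale):
        ctx.save_for_backward(x, weight)
        ctx.scale = scale  # non-tensor args are stashed on ctx, like ctx.softcap
        return (x @ weight) * scale

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight.t() * ctx.scale
        grad_w = x.t() @ grad_out * ctx.scale
        # One return value per forward() input: the float `scale` has no
        # gradient, so its slot is None. Adding softcap to forward() is what
        # adds one more None to each backward() return in the diff above.
        return grad_x, grad_w, None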
@@ -303,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -318,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
@@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )
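For context on what the new parameter means: softcap is commonly implemented as tanh soft-capping of the pre-softmax attention scores, bounding them to (-softcap, softcap), with 0.0 (the default added above) disabling it. The sketch below shows that assumed semantics with a hypothetical softcap_scores helper; it is not the actual attention_ref code from the tests:

import math
import torch

def softcap_scores(q, k, softcap=0.0, softmax_scale=None):
    # q: (batch, seqlen_q, nheads, headdim), k: (batch, seqlen_k, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    if softcap > 0.0:
        # Bound the pre-softmax logits to (-softcap, softcap).
        scores = softcap * torch.tanh(scores / softcap)
    return scores

With softcap=0.0 the scores pass through unchanged, which keeps the previous behavior of the reference functions.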