The ref "...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on commit "5856e5d0f7f88f325ab2d43b3b0878d7ac676a1c"
Commit 81e01efd authored by Tri Dao's avatar Tri Dao
Browse files

More typo fixes

parent 72e27c63
......@@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
return_softmax,
block_table,
......@@ -102,6 +103,7 @@ def _flash_attn_varlen_forward(
causal,
window_size[0],
window_size[1],
softcap,
return_softmax,
None,
)
......@@ -300,6 +302,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -318,6 +321,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
......@@ -328,6 +332,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -355,12 +360,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
......@@ -373,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -387,6 +394,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
......@@ -395,6 +403,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -419,13 +428,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
......
......@@ -303,6 +303,7 @@ def attention_kvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -318,6 +319,7 @@ def attention_kvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
......@@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment