"vscode:/vscode.git/clone" did not exist on "9af267d247f4af341e614a9c2cf2ee5272e796a2"
Commit 06fe2294 authored by one

Fix the vortex attention interface by adding dropout parameters and updating the function signature

parent 1516fed0
@@ -13,11 +13,11 @@
+++ vortex/model/attention.py 2026-01-19 10:47:28.981582989 +0800
@@ -26,6 +26,7 @@
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
from vortex.model.rotary import RotaryEmbedding
+from flash_attn.flash_attn_interface import flash_attn_kvpacked_func as dcu_flash_attn_kvpacked_fun
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -215,16 +216,19 @@
batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -37,8 +37,8 @@
+ q,
+ kv,
+ self.drop.p if self.training else 0.0,
-+ softmax_scale=None,
-+ causal=False,
++ causal=causal,
++ softmax_scale=self.softmax_scale,
+ alibi_slopes=self.alibi_slopes,
+ window_size=self.window_size,
+ deterministic=self.deterministic,
@@ -46,8 +46,8 @@
+ return_attn_probs=False,
+ bhsd=False
+ )
class SelfAttention(nn.Module):
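
Taken together, the attention.py hunks thread the attention dropout probability through to the fused kernel as a positional argument and replace the previously hard-coded `causal=False` / `softmax_scale=None` with the module's own values. A minimal sketch of the resulting call site, assuming the DCU flash-attn fork imported above (only the import alias and the argument order/keywords are taken from the patch; the class name, constructor, and defaults are hypothetical):

```python
import torch.nn as nn
# Assumed: the DCU flash-attn fork referenced by the patch above.
from flash_attn.flash_attn_interface import (
    flash_attn_kvpacked_func as dcu_flash_attn_kvpacked_fun,
)

class CrossAttention(nn.Module):
    """Hypothetical reconstruction of the patched call site; only the
    argument order and keywords are taken from the hunks above."""

    def __init__(self, softmax_scale=None, attention_dropout=0.0,
                 alibi_slopes=None, window_size=(-1, -1), deterministic=False):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.alibi_slopes = alibi_slopes
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, q, kv, causal=False):
        # dropout_p is the new positional argument: active during
        # training, forced to 0.0 at inference time.
        return dcu_flash_attn_kvpacked_fun(
            q,
            kv,
            self.drop.p if self.training else 0.0,
            causal=causal,                     # was hard-coded False
            softmax_scale=self.softmax_scale,  # was None
            alibi_slopes=self.alibi_slopes,
            window_size=self.window_size,
            deterministic=self.deterministic,
            return_attn_probs=False,
            bhsd=False,
        )
```

Guarding with `self.training` matters because the fused kernel applies dropout internally and knows nothing about `nn.Module.eval()`; passing the raw `self.drop.p` unconditionally would keep dropping attention weights at inference time.
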
--- vortex/ops/attn_interface.py.orig 2026-01-19 10:41:45.456424582 +0800
+++ vortex/ops/attn_interface.py 2026-01-19 10:47:28.983582996 +0800
@@ -60,3 +60,20 @@
q,
k,
v,
@@ -72,6 +72,9 @@
softcap,
return_softmax,
None,
+ False,
+ None,
+ 0.0,
)
return out, softmax_lse, S_dmask, rng_state
@@ -1624,5 +1627,6 @@
softcap,
rotary_interleaved,
num_splits,
+ None,
)
return (out, softmax_lse) if return_softmax_lse else out
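
The attn_interface.py hunks follow the usual pattern for tracking a compiled kernel whose signature has grown trailing parameters: every Python wrapper must now pass a value for each new parameter, in order, even when it is only a default (`False`, `None`, `0.0` in the forward path; a single `None` after `num_splits` in the kv-cache path at `@@ -1624,5 +1627,6 @@`). A minimal runnable sketch of that pattern, with a pure-Python stand-in for the extension entry point (all names below are hypothetical; only the appended literals come from the patch):

```python
def fwd_kernel(q, k, v, softcap, return_softmax, gen,
               new_flag, new_tensor, new_scale):
    # Pure-Python stand-in for the fork's compiled forward kernel.
    # Returns dummy outputs so the sketch runs; the real kernel
    # computes the attention outputs and softmax statistics.
    out, softmax_lse, S_dmask, rng_state = q, None, None, None
    return out, softmax_lse, S_dmask, rng_state

def _flash_attn_forward(q, k, v, softcap, return_softmax):
    # Once the kernel grows three trailing parameters, a positional
    # call that omits them raises a TypeError, so the wrapper must be
    # updated in lockstep, which is what this commit does.
    out, softmax_lse, S_dmask, rng_state = fwd_kernel(
        q, k, v,
        softcap,
        return_softmax,
        None,   # RNG generator, unchanged
        False,  # appended by this commit; semantics are fork-specific
        None,   # appended by this commit
        0.0,    # appended by this commit
    )
    return out, softmax_lse, S_dmask, rng_state

print(_flash_attn_forward("q", "k", "v", 0.0, False))
```
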