vortex_fix.patch — adapt vortex flash-attention calls to the DCU flash_attn interface
--- vortex/model/utils.py.orig  2026-01-19 10:41:45.455424578 +0800
+++ vortex/model/utils.py       2026-01-19 10:47:28.980582986 +0800
@@ -114,7 +114,7 @@
             mmap=True,
-            # Make sure PyTorch is not issuing a warning regarding potential
-            # security issues.
-            weights_only=True,
+            # NOTE: weights_only=False lets pickled checkpoints execute
+            # arbitrary code on load — only load trusted checkpoint files.
+            weights_only=False,
         )
         model.to_bfloat16_except_pr_lc(to_float32=True)

--- vortex/model/attention.py.orig      2026-01-19 10:41:45.453424571 +0800
+++ vortex/model/attention.py   2026-01-19 10:47:28.981582989 +0800
@@ -26,6 +26,7 @@
 FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
 
 
 from vortex.model.rotary import RotaryEmbedding
+from flash_attn.flash_attn_interface import flash_attn_kvpacked_func as dcu_flash_attn_kvpacked_fun
 
 
@@ -215,16 +216,19 @@
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
-            return local_flash_attn_kvpacked_func(
-                q,
-                kv,
-                self.drop.p if self.training else 0.0,
-                causal=causal,
-                softmax_scale=self.softmax_scale,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
-                deterministic=self.deterministic,
-            )
+            return dcu_flash_attn_kvpacked_fun(
+                 q,
+                 kv,
+                 self.drop.p if self.training else 0.0,
+                 causal=causal,
+                 softmax_scale=self.softmax_scale,
+                 alibi_slopes=self.alibi_slopes,
+                 window_size=self.window_size,
+                 deterministic=self.deterministic,
+                 softcap=0.0,
+                 return_attn_probs=False,
+                 bhsd=False
+             )
 
 
 class SelfAttention(nn.Module):
--- vortex/ops/attn_interface.py.orig   2026-01-19 10:41:45.456424582 +0800
+++ vortex/ops/attn_interface.py        2026-01-19 10:47:28.983582996 +0800
@@ -58,7 +58,7 @@
     return_softmax: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
         q,
         k,
         v,
@@ -72,6 +72,9 @@
         softcap,
         return_softmax,
         None,
+        False,
+        None,
+        0.0,
     )
     return out, softmax_lse, S_dmask, rng_state
 
@@ -1624,5 +1627,6 @@
         softcap,
         rotary_interleaved,
         num_splits,
+        None,
     )
     return (out, softmax_lse) if return_softmax_lse else out