vortex_fix.patch
--- vortex/model/utils.py.orig  2026-01-19 10:41:45.455424578 +0800
+++ vortex/model/utils.py       2026-01-19 10:47:28.980582986 +0800
@@ -114,7 +114,7 @@
             mmap=True,
             # Make sure PyTorch is not issuing a warning regarding potential
             # security issues.
-            weights_only=True,
+            weights_only=False,
         )
         model.to_bfloat16_except_pr_lc(to_float32=True)

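The hunk above relaxes torch.load's safe-loading mode so that checkpoints containing
pickled Python objects (not just raw tensors) can still be deserialized. A minimal
sketch of the resulting call pattern, assuming an illustrative checkpoint path (the
path below is not taken from the repository):

    import torch

    state = torch.load(
        "checkpoint.pt",     # hypothetical path, for illustration only
        mmap=True,
        weights_only=False,  # as patched above; disables the safe-loading restriction
    )
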
--- vortex/model/attention.py.orig      2026-01-19 10:41:45.453424571 +0800
+++ vortex/model/attention.py   2026-01-19 10:47:28.981582989 +0800
@@ -26,6 +26,7 @@
 FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

 from vortex.model.rotary import RotaryEmbedding
+from flash_attn.flash_attn_interface import flash_attn_kvpacked_func as dcu_flash_attn_kvpacked_fun


 # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -215,16 +216,19 @@
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
-            return local_flash_attn_kvpacked_func(
-                q,
-                kv,
-                self.drop.p if self.training else 0.0,
-                causal=causal,
-                softmax_scale=self.softmax_scale,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
-                deterministic=self.deterministic,
-            )
+            return dcu_flash_attn_kvpacked_fun(
+                q,
+                kv,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=None,
+                causal=False,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
+                deterministic=self.deterministic,
+                softcap=0.0,
+                return_attn_probs=False,
+                bhsd=False,
+            )


 class SelfAttention(nn.Module):
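
The attention.py hunk above reroutes the packed-KV attention call to what the patch
names a DCU flash_attn implementation. A minimal smoke-test sketch against the
upstream flash_attn interface (fork-specific keywords such as bhsd are omitted here);
tensor shapes, dtype and device are illustrative assumptions, following the documented
layout q: (batch, seqlen_q, nheads, headdim) and kv: (batch, seqlen_k, 2, nheads, headdim):

    import torch
    from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    kv = torch.randn(2, 128, 2, 8, 64, device="cuda", dtype=torch.bfloat16)

    # dropout_p=0.0, no causal mask, default softmax scale -- mirrors the
    # keyword choices made in the patched call above.
    out = flash_attn_kvpacked_func(q, kv, 0.0, softmax_scale=None, causal=False)
    print(out.shape)  # expected: torch.Size([2, 128, 8, 64])
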
--- vortex/ops/attn_interface.py.orig   2026-01-19 10:41:45.456424582 +0800
+++ vortex/ops/attn_interface.py        2026-01-19 10:47:28.983582996 +0800
@@ -58,7 +58,7 @@
     return_softmax: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
         q,
         k,
         v,