--- vortex/model/utils.py.orig 2026-01-19 10:41:45.455424578 +0800
+++ vortex/model/utils.py 2026-01-19 10:47:28.980582986 +0800
@@ -114,7 +114,7 @@
         mmap=True,
         # Make sure PyTorch is not issuing a warning regarding potential
         # security issues.
-        weights_only=True,
+        weights_only=False,
     )
 
     model.to_bfloat16_except_pr_lc(to_float32=True)
--- vortex/model/attention.py.orig 2026-01-19 10:41:45.453424571 +0800
+++ vortex/model/attention.py 2026-01-19 10:47:28.981582989 +0800
@@ -26,6 +26,7 @@
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
 
 from vortex.model.rotary import RotaryEmbedding
+from flash_attn.flash_attn_interface import flash_attn_kvpacked_func as dcu_flash_attn_kvpacked_fun
 
 
 # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -215,16 +216,19 @@
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
-        return local_flash_attn_kvpacked_func(
-            q,
-            kv,
-            self.drop.p if self.training else 0.0,
-            causal=causal,
-            softmax_scale=self.softmax_scale,
-            alibi_slopes=self.alibi_slopes,
-            window_size=self.window_size,
-            deterministic=self.deterministic,
-        )
+        return dcu_flash_attn_kvpacked_fun(
+            q,
+            kv,
+            self.drop.p if self.training else 0.0,
+            causal=causal,
+            softmax_scale=self.softmax_scale,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.window_size,
+            deterministic=self.deterministic,
+            softcap=0.0,
+            return_attn_probs=False,
+            bhsd=False
+        )
 
 
 class SelfAttention(nn.Module):
--- vortex/ops/attn_interface.py.orig 2026-01-19 10:41:45.456424582 +0800
+++ vortex/ops/attn_interface.py 2026-01-19 10:47:28.983582996 +0800
@@ -58,7 +58,7 @@
     return_softmax: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
         q,
         k,
         v,
@@ -72,6 +72,9 @@
         softcap,
         return_softmax,
         None,
+        False,
+        None,
+        0.0,
     )
     return out, softmax_lse, S_dmask, rng_state
 
@@ -1624,5 +1627,6 @@
         softcap,
         rotary_interleaved,
         num_splits,
+        None,
     )
     return (out, softmax_lse) if return_softmax_lse else out