Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ff70821
Unverified
Commit
0ff70821
authored
Nov 23, 2025
by
Roger Wang
Committed by
GitHub
Nov 24, 2025
Browse files
[Core] Deprecate `xformers` (#29262)
Signed-off-by:
Roger Wang
<
hey@rogerw.io
>
parent
5253f427
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
39 additions
and
545 deletions
+39
-545
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+6
-25
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+17
-13
vllm/model_executor/models/paddleocr_vl.py
vllm/model_executor/models/paddleocr_vl.py
+0
-13
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+1
-0
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+4
-21
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+4
-27
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+3
-9
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+3
-10
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-6
vllm/utils/__init__.py
vllm/utils/__init__.py
+0
-1
vllm/v1/attention/backends/xformers.py
vllm/v1/attention/backends/xformers.py
+0
-420
No files found.
vllm/model_executor/models/glm4_1v.py
View file @
0ff70821
...
...
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
...
...
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
kv_seqlen
=
None
,
device
=
q
.
device
)
context_layer
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
attn_bias
=
attn_bias
,
p
=
0
,
scale
=
None
)
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
output
,
_
=
self
.
proj
(
context_layer
)
return
output
...
...
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x_attn
=
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x_fused_norm
,
residual
=
self
.
norm2
(
x
,
residual
=
x_attn
)
x
=
residual
+
self
.
mlp
(
x_fused_norm
)
...
...
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
int
|
None
,
list
[
int
]
|
None
]:
max_seqlen
,
seqlens
=
None
,
None
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
)
->
int
|
None
:
max_seqlen
=
None
if
(
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
return
max_seqlen
,
seqlens
return
max_seqlen
def
forward
(
self
,
...
...
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
"constant"
,
0
)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
x
=
self
.
embeddings
(
x
,
seqlens
,
grid_thw
,
image_type_ids
[:,
0
],
image_type_ids
[:,
1
]
)
...
...
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
# adapter
...
...
vllm/model_executor/models/keye.py
View file @
0ff70821
...
...
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
PretrainedConfig
from
transformers.activations
import
GELUActivation
...
...
@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
batch_size
=
q
.
shape
[
0
]
if
rope_emb
is
None
:
...
...
@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale
=
self
.
scale
,
)
context_layer
=
rearrange
(
output
,
"(b s) ... -> b s ..."
,
b
=
batch_size
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
kv_seqlen
=
None
,
device
=
q
.
device
)
context_layer
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
attn_bias
=
attn_bias
,
p
=
0
,
scale
=
None
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
outputs
=
[]
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
start_idx
=
cu_seqlens
[
i
-
1
]
end_idx
=
cu_seqlens
[
i
]
q_i
=
q
[:,
start_idx
:
end_idx
]
k_i
=
k
[:,
start_idx
:
end_idx
]
v_i
=
v
[:,
start_idx
:
end_idx
]
q_i
,
k_i
,
v_i
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
(
q_i
,
k_i
,
v_i
)
)
output_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
output_i
=
rearrange
(
output_i
,
"b h s d -> b s h d "
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
if
outputs
else
q
[:,
:
0
]
context_layer
=
rearrange
(
context_layer
,
"b s h d -> b s (h d)"
).
contiguous
()
...
...
vllm/model_executor/models/paddleocr_vl.py
View file @
0ff70821
...
...
@@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from
vllm.attention.ops.vit_attn_wrappers
import
(
vit_flash_attn_wrapper
,
vit_xformers_attn_wrapper
,
)
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
...
...
@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
|
None
,
max_seqlen
:
torch
.
Tensor
|
None
,
seqlens
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
batch_size
,
_
,
_
=
hidden_states
.
shape
...
...
@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
if
seqlens
is
None
:
raise
ValueError
(
"xFormers attention backend requires seqlens tensor."
)
context_layer
=
vit_xformers_attn_wrapper
(
q
,
k
,
v
,
seqlens
)
else
:
raise
RuntimeError
(
f
"PaddleOCR-VL does not support
{
self
.
attn_backend
}
backend now."
...
...
@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
|
None
,
max_seqlen
:
torch
.
Tensor
|
None
,
seqlens
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
...
...
@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens
=
cu_seqlens
.
to
(
device
=
device
)
max_seqlen
=
None
seqlens
=
None
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
seqlens
=
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
...
...
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
return
hidden_states
...
...
vllm/model_executor/models/pixtral.py
View file @
0ff70821
...
...
@@ -74,6 +74,7 @@ from .vision import (
)
try
:
# Note: vLLM does not install xformers by default.
from
xformers
import
ops
as
xops
if
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
100
):
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
0ff70821
...
...
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from
vllm.attention.ops.vit_attn_wrappers
import
(
vit_flash_attn_wrapper
,
vit_torch_sdpa_wrapper
,
vit_xformers_attn_wrapper
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
...
...
@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
...
...
@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v
,
cu_seqlens
,
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
context_layer
=
vit_xformers_attn_wrapper
(
q
,
k
,
v
,
seqlens
)
output
,
_
=
self
.
proj
(
context_layer
)
return
output
...
...
@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens"
:
0
,
"rotary_pos_emb_cos"
:
0
,
"rotary_pos_emb_sin"
:
0
,
"seqlens"
:
0
,
},
mark_unbacked_dims
=
{
"seqlens"
:
0
},
enable_if
=
should_torch_compile_mm_vit
,
)
class
Qwen2_5_VisionBlock
(
nn
.
Module
):
...
...
@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x_attn
=
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x_fused_norm
,
residual
=
self
.
norm2
(
x
,
residual
=
x_attn
)
x
=
residual
+
self
.
mlp
(
x_fused_norm
)
...
...
@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
max_seqlen
=
torch
.
zeros
([],
device
=
cu_seqlens
.
device
)
seqlens
=
torch
.
zeros
(
1
,
device
=
cu_seqlens
.
device
)
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
seqlens
=
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
return
max_seqlen
,
seqlens
return
max_seqlen
@
staticmethod
def
invert_permutation
(
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full
,
seqlens_full
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen_window
,
seqlens_window
=
self
.
compute_attn_mask_seqlen
(
cu_window_seqlens
)
max_seqlen_full
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen_window
=
self
.
compute_attn_mask_seqlen
(
cu_window_seqlens
)
cu_seqlens
=
cu_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
cu_window_seqlens
=
cu_window_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
...
...
@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if
layer_num
in
self
.
fullatt_block_indexes
:
cu_seqlens_now
=
cu_seqlens
max_seqlen_now
=
max_seqlen_full
seqlens_now
=
seqlens_full
else
:
cu_seqlens_now
=
cu_window_seqlens
max_seqlen_now
=
max_seqlen_window
seqlens_now
=
seqlens_window
hidden_states
=
blk
(
hidden_states
,
...
...
@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen_now
,
seqlens
=
seqlens_now
,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
0ff70821
...
...
@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x
,
_
=
self
.
qkv
(
x
)
...
...
@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
kv_seqlen
=
None
,
device
=
q
.
device
)
context_layer
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
attn_bias
=
attn_bias
,
p
=
0
,
scale
=
None
)
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
output
,
_
=
self
.
proj
(
context_layer
)
return
output
...
...
@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
...
...
@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
return
cos_combined
,
sin_combined
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
tuple
[
int
|
None
,
list
[
int
]
|
None
]:
max_seqlen
,
seqlens
=
None
,
None
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
int
|
None
:
max_seqlen
=
None
if
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
return
max_seqlen
,
seqlens
return
max_seqlen
def
forward
(
self
,
...
...
@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x
=
x
.
unsqueeze
(
1
)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
cu_seqlens
=
cu_seqlens
.
to
(
self
.
device
,
non_blocking
=
True
)
for
blk
in
self
.
blocks
:
x
=
blk
(
...
...
@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
# adapter
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
0ff70821
...
...
@@ -224,7 +224,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -232,7 +231,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
...
...
@@ -500,14 +498,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
max_seqlen
=
torch
.
zeros
([],
device
=
cu_seqlens
.
device
)
seqlens
=
torch
.
zeros
(
1
,
device
=
cu_seqlens
.
device
)
if
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
seqlens
=
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
return
max_seqlen
,
seqlens
return
max_seqlen
def
forward
(
self
,
...
...
@@ -533,7 +528,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
rotary_pos_emb_cos
=
rotary_pos_emb_cos
.
to
(
hidden_states
.
device
)
rotary_pos_emb_sin
=
rotary_pos_emb_sin
.
to
(
hidden_states
.
device
)
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
hidden_states_list
=
[]
deepstack_visual_indexes
=
self
.
deepstack_visual_indexes
...
...
@@ -545,7 +540,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
if
(
deepstack_visual_indexes
is
not
None
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
0ff70821
...
...
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
...
...
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
XFORMERS
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
...
...
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
max_seqlen
=
torch
.
zeros
([],
device
=
cu_seqlens
.
device
)
seqlens
=
torch
.
zeros
(
1
,
device
=
cu_seqlens
.
device
)
if
(
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
:
seqlens
=
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
return
max_seqlen
,
seqlens
return
max_seqlen
def
forward
(
self
,
...
...
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens
=
torch
.
from_numpy
(
cu_seqlens
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
cu_seqlens
=
cu_seqlens
.
to
(
self
.
device
,
non_blocking
=
True
)
deepstack_feature_lists
=
[]
...
...
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
if
layer_num
in
self
.
deepstack_visual_indexes
:
deepstack_merger_idx
=
self
.
deepstack_visual_indexes
.
index
(
layer_num
)
...
...
vllm/platforms/cuda.py
View file @
0ff70821
...
...
@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
except
ImportError
:
pass
if
cls
.
has_device_capability
(
100
):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return
AttentionBackendEnum
.
TORCH_SDPA
else
:
return
AttentionBackendEnum
.
XFORMERS
return
AttentionBackendEnum
.
TORCH_SDPA
@
classmethod
def
get_valid_backends
(
...
...
vllm/utils/__init__.py
View file @
0ff70821
...
...
@@ -49,7 +49,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL
:
str
=
"FLASHINFER"
STR_XFORMERS_ATTN_VAL
:
str
=
"XFORMERS"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
...
...
vllm/v1/attention/backends/xformers.py
deleted
100644 → 0
View file @
5253f427
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with XFormersAttention."""
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
Optional
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
,
MultipleOf
,
)
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
try
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
,
)
XFORMERS_AVAILABLE
=
True
except
ImportError
:
XFORMERS_AVAILABLE
=
False
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
class
XFormersAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
40
,
48
,
56
,
64
,
72
,
80
,
88
,
96
,
104
,
112
,
120
,
128
,
136
,
144
,
152
,
160
,
168
,
176
,
184
,
192
,
200
,
208
,
216
,
224
,
232
,
240
,
248
,
256
,
]
@
staticmethod
def
get_name
()
->
str
:
return
"XFORMERS"
@
staticmethod
def
get_impl_cls
()
->
type
[
"XFormersAttentionImpl"
]:
return
XFormersAttentionImpl
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_builder_cls
()
->
type
[
"XFormersAttentionMetadataBuilder"
]:
return
XFormersAttentionMetadataBuilder
@
staticmethod
def
use_cascade_attention
(
*
args
,
**
kwargs
)
->
bool
:
return
False
@
dataclass
class
XFormersAttentionMetadata
:
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_prefill_tokens
:
int
=
0
num_decode_tokens
:
int
=
0
num_prefills
:
int
=
0
num_decodes
:
int
=
0
# Biases for different attention types.
attn_bias
:
Optional
[
"AttentionBias"
]
=
None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata
:
Optional
[
"XFormersAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersAttentionMetadata"
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
# Recover cached prefill-phase attention
# metadata structure
return
self
.
_cached_prefill_metadata
q_start_loc
=
self
.
query_start_loc
[
self
.
num_decodes
:]
q_seqlens
=
torch
.
diff
(
q_start_loc
)
kv_seqlens
=
self
.
seq_lens
[
self
.
num_decodes
:]
# Construct & cache prefill-phase attention metadata structure
self
.
_cached_prefill_metadata
=
XFormersAttentionMetadata
(
num_actual_tokens
=
self
.
num_prefill_tokens
,
max_query_len
=
int
(
q_seqlens
.
max
().
item
()),
query_start_loc
=
q_start_loc
-
q_start_loc
[
0
],
max_seq_len
=
int
(
kv_seqlens
.
max
().
item
()),
seq_lens
=
kv_seqlens
,
block_table
=
self
.
block_table
[
self
.
num_decodes
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_decode_tokens
:],
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"XFormersAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
# Recover cached decode-phase attention
# metadata structure
return
self
.
_cached_decode_metadata
q_start_loc
=
self
.
query_start_loc
q_seqlens
=
torch
.
diff
(
q_start_loc
)
decode_kv_seqlens
=
self
.
seq_lens
[:
self
.
num_decodes
]
# Construct & cache decode-phase attention metadata structure
self
.
_cached_decode_metadata
=
XFormersAttentionMetadata
(
num_actual_tokens
=
self
.
num_decode_tokens
,
max_query_len
=
int
(
q_seqlens
[:
self
.
num_decodes
].
max
().
item
()),
query_start_loc
=
q_start_loc
[:
self
.
num_decodes
+
1
],
max_seq_len
=
int
(
decode_kv_seqlens
.
max
().
item
()),
seq_lens
=
decode_kv_seqlens
,
block_table
=
self
.
block_table
[:
self
.
num_decodes
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_decode_tokens
],
attn_bias
=
self
.
attn_bias
,
)
return
self
.
_cached_decode_metadata
class
XFormersAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
XFormersAttentionMetadata
]
):
reorder_batch_threshold
:
int
=
1
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
assert
XFORMERS_AVAILABLE
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
_num_decodes
=
0
self
.
_num_decode_tokens
=
0
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
XFormersAttentionMetadata
:
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
q_start_loc
=
common_attn_metadata
.
query_start_loc
q_seqlens
=
torch
.
diff
(
q_start_loc
)
max_query_len
=
common_attn_metadata
.
max_query_len
kv_seqlens
=
common_attn_metadata
.
seq_lens
max_seq_len
=
common_attn_metadata
.
max_seq_len
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
bias
=
None
if
num_decodes
>
0
:
# Construct the decoder bias.
decode_q_seqlens
=
q_seqlens
[:
num_decodes
]
decode_kv_seqlens
=
kv_seqlens
[:
num_decodes
]
bias
=
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
.
from_seqlens
(
q_seqlen
=
decode_q_seqlens
.
tolist
(),
kv_seqlen
=
decode_kv_seqlens
.
tolist
(),
page_size
=
self
.
block_size
,
block_tables
=
block_table
[:
num_decodes
],
device
=
block_table
.
device
,
)
return
XFormersAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_prefills
=
num_prefills
,
num_decodes
=
num_decodes
,
max_query_len
=
max_query_len
,
query_start_loc
=
q_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
kv_seqlens
,
block_table
=
block_table
,
slot_mapping
=
slot_mapping
,
attn_bias
=
bias
,
)
class
XFormersAttentionImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
str
|
None
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"XFormers does not support alibi slopes yet."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
None
:
self
.
sliding_window
=
(
-
1
,
-
1
)
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
if
logits_soft_cap
is
None
:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"XFormersAttentionImpl."
)
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
XFormersAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with XFormers.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl"
)
if
attn_metadata
is
None
:
# Profiling run.
return
output
.
fill_
(
0
)
# Cache the input KVs.
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
unified_attention
(
q
=
query
[
num_decode_tokens
:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[
num_decode_tokens
:
num_actual_tokens
],
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
prefill_meta
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
q_descale
=
None
,
# Not supported
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[:
num_decode_tokens
]
# Reshape query to [1, B_T, G, H, D].
q
=
decode_query
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
)
# Reshape the k and v caches to [1, Bkv_T, G, H, D]
cache_k
=
key_cache
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
1
,
self
.
head_size
).
expand
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
,
)
cache_v
=
value_cache
.
view
(
1
,
-
1
,
self
.
num_kv_heads
,
1
,
self
.
head_size
).
expand
(
1
,
-
1
,
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
self
.
head_size
,
)
attn_bias
=
decode_meta
.
attn_bias
output
[:
num_decode_tokens
]
=
xops
.
memory_efficient_attention_forward
(
q
,
cache_k
,
cache_v
,
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
self
.
scale
,
).
view
(
decode_query
.
shape
)
# Reshape the output tensor.
return
output
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment