Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5782581a
Unverified
Commit
5782581a
authored
Jul 19, 2025
by
hax0r31337
Committed by
GitHub
Jul 18, 2025
Browse files
[Bugfix] Voxtral on Blackwell GPUs (RTX 50 series) (#21077)
Signed-off-by:
hax0r31337
<
liulihaocaiqwq@gmail.com
>
parent
0f199f19
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
0 deletions
+33
-0
vllm/attention/layer.py
vllm/attention/layer.py
+33
-0
No files found.
vllm/attention/layer.py
View file @
5782581a
...
...
@@ -16,6 +16,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
...
...
@@ -23,6 +24,34 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
USE_XFORMERS_OPS
=
None
def
check_xformers_availability
():
global
USE_XFORMERS_OPS
if
USE_XFORMERS_OPS
is
not
None
:
return
USE_XFORMERS_OPS
if
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
100
):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS
=
False
else
:
try
:
from
importlib.util
import
find_spec
find_spec
(
"xformers.ops"
)
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
# the warning only needs to be shown once
if
not
USE_XFORMERS_OPS
:
logger
.
warning
(
"Xformers is not available, falling back."
)
return
USE_XFORMERS_OPS
class
Attention
(
nn
.
Module
):
"""Attention layer.
...
...
@@ -314,6 +343,10 @@ class MultiHeadAttention(nn.Module):
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
}
else
_Backend
.
TORCH_SDPA
if
(
self
.
attn_backend
==
_Backend
.
XFORMERS
and
not
check_xformers_availability
()):
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
def
forward
(
self
,
query
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment