Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81c1bb0
Unverified
Commit
f81c1bb0
authored
Aug 01, 2025
by
Michael Goin
Committed by
GitHub
Aug 01, 2025
Browse files
[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)
parent
fb0e0d46
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
99 deletions
+93
-99
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-44
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+80
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+3
-46
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+8
-8
No files found.
vllm/attention/backends/flashinfer.py
View file @
f81c1bb0
...
...
@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
from
vllm.utils.flashinfer
import
use_trtllm_decode_attention
logger
=
init_logger
(
__name__
)
...
...
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
class
FlashInferBackend
(
AttentionBackend
):
cached_sm100a_supported
:
Optional
[
bool
]
=
None
@
staticmethod
def
get_name
()
->
str
:
...
...
@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend):
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@
staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    """Report whether decode attention should take the TRTLLM path.

    The result of ``current_platform.has_device_capability(100)`` is
    stored on ``FlashInferBackend.cached_sm100a_supported`` the first time
    this runs and reused afterwards.

    ``False`` is returned when that capability check fails, when any head
    parameter is ``None``, when ``num_qo_heads // num_kv_heads`` exceeds 8,
    when ``num_qo_heads`` is not a multiple of ``num_kv_heads``, or when
    ``attn_head_size`` is not 128.

    Otherwise ``envs.VLLM_USE_TRTLLM_DECODE_ATTENTION`` decides when it is
    set (every value except ``"0"`` enables the path). When it is unset
    the path is selected for ``batch_size <= 256``,
    ``max_seq_len < 131072`` and ``kv_cache_dtype == "auto"``.
    """
    backend = FlashInferBackend
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    # Reject head configurations TRTLLM decode attention does not support.
    # The checks run in the same order as a single short-circuiting `or`.
    if attn_head_size is None or num_qo_heads is None:
        return False
    if num_kv_heads is None:
        return False
    if num_qo_heads // num_kv_heads > 8:
        return False
    if num_qo_heads % num_kv_heads != 0:
        return False
    if attn_head_size != 128:
        return False

    override = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if override is None:
        # Environment variable not set - fall back to auto-detection.
        auto_selected = (backend.cached_sm100a_supported
                         and batch_size <= 256 and max_seq_len < 131072
                         and kv_cache_dtype == "auto")
        if auto_selected:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_selected

    # Environment variable is set - respect it. Only "0" disables the
    # path, because it is otherwise enabled automatically when the batch
    # size condition is satisfied.
    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     override)
    enabled = not (override == "0")
    if enabled:
        logger.info_once("Using TRTLLM decode attention.")
    return enabled
@
dataclass
class
PerLayerParameters
:
...
...
@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
assert
decode_meta
.
decode_wrapper
.
_sm_scale
==
softmax_scale
# TODO: @pavanimajety Remove this once the switch happens
# inside flashinfer.
if
not
FlashInferBackend
.
use_trtllm_decode_attention
(
if
not
use_trtllm_decode_attention
(
num_decode_tokens
,
attn_metadata
.
max_decode_seq_len
,
kv_cache_dtype
,
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
):
...
...
vllm/utils/flashinfer.py
View file @
f81c1bb0
...
...
@@ -10,12 +10,25 @@ import contextlib
import
functools
import
importlib
import
importlib.util
from
typing
import
Any
,
Callable
,
NoReturn
import
os
from
typing
import
Any
,
Callable
,
NoReturn
,
Optional
import
requests
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# Base location the flashinfer cubins are fetched from. Set the
# FLASHINFER_CUBINS_REPOSITORY environment variable to override it, e.g. with
# a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)
@
functools
.
cache
def
has_flashinfer
()
->
bool
:
...
...
@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
return
True
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FMHA.

    A single GET request is sent to ``FLASHINFER_CUBINS_REPOSITORY`` with a
    5 second timeout, and only an HTTP 200 answer counts as accessible.
    Because of ``functools.cache`` the first result is reused by later
    calls, so a failed probe is not retried.

    Returns:
        ``True`` when the repository answered with status 200; ``False``
        for any other status code or when the request raised.
    """
    try:
        # Use a short timeout to avoid blocking for too long.
        # ``stream=True`` defers the body download: this probe only needs
        # the status code, not the repository listing itself.
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY,
                                timeout=5,
                                stream=True)
    except Exception as e:
        # Deliberately broad: this is a best-effort probe, so any failure
        # to reach the repository is reported as "not accessible".
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    status_code = response.status_code
    # The body is never read, so release the connection explicitly.
    response.close()

    accessible = status_code == 200
    if accessible:
        logger.debug_once("NVIDIA artifactory is accessible")
    else:
        logger.warning_once(
            "NVIDIA artifactory returned failed status code: %d",
            status_code)
    return accessible
def use_trtllm_decode_attention(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    """Decide whether decode attention should use the TRTLLM kernels.

    TRTLLM decode attention requires SM100
    (``current_platform.is_device_capability(100)``) and a reachable NVIDIA
    artifactory (``has_nvidia_artifactory()``), from which the cubins are
    downloaded.

    The artifactory probe performs a network request on its first call, so
    it is evaluated last: it only runs when every local condition still
    allows TRTLLM. The returned value for any given input is the same as
    if the probe were checked first.

    Args:
        num_tokens: Number of decode tokens; auto-detection requires
            ``num_tokens <= 256``.
        max_seq_len: Maximum sequence length; auto-detection requires
            ``max_seq_len < 131072``.
        kv_cache_dtype: KV cache dtype string; auto-detection requires
            ``"auto"``.
        num_qo_heads: Number of query heads, or ``None`` if unknown.
        num_kv_heads: Number of KV heads, or ``None`` if unknown.
        attn_head_size: Attention head size, or ``None`` if unknown.

    Returns:
        ``True`` if the TRTLLM decode attention path should be used.
    """
    # Requires SM100.
    if not current_platform.is_device_capability(100):
        return False

    # Check if the dimensions are supported by TRTLLM decode attention
    if (attn_head_size is None or num_qo_heads is None
            or num_kv_heads is None or num_qo_heads // num_kv_heads > 8
            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                         env_value)
        # Environment variable is set - respect it
        # Making the conditional check for zero because
        # the path is automatically enabled if the batch size condition
        # is satisfied.
        if env_value == "0":
            # Explicitly disabled: no need to probe the network at all.
            return False
        # The cubins still have to be downloadable for the path to work.
        if not has_nvidia_artifactory():
            return False
        logger.info_once("Using TRTLLM decode attention.")
        return True

    # Environment variable not set - use auto-detection. The artifactory
    # probe comes last so that `and` short-circuits before reaching it.
    use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                  and kv_cache_dtype == "auto" and has_nvidia_artifactory())
    if use_trtllm:
        logger.warning_once("Using TRTLLM decode attention (auto-detected).")
    return use_trtllm
__all__
=
[
"has_flashinfer"
,
"flashinfer_trtllm_fp8_block_scale_moe"
,
...
...
@@ -117,4 +194,6 @@ __all__ = [
"autotune"
,
"has_flashinfer_moe"
,
"has_flashinfer_cutlass_fused_moe"
,
"has_nvidia_artifactory"
,
"use_trtllm_decode_attention"
,
]
vllm/v1/attention/backends/flashinfer.py
View file @
f81c1bb0
...
...
@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils.flashinfer
import
use_trtllm_decode_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
...
...
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
class
FlashInferBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
cached_sm100a_supported
:
Optional
[
bool
]
=
None
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
...
...
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
@
staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool:
    """Report whether decode attention should take the TRTLLM path.

    The result of ``current_platform.has_device_capability(100)`` is
    stored on ``FlashInferBackend.cached_sm100a_supported`` the first time
    this runs and reused afterwards.

    ``False`` is returned when that capability check fails, when
    ``num_qo_heads // num_kv_heads`` exceeds 8, when ``num_qo_heads`` is
    not a multiple of ``num_kv_heads``, or when ``attn_head_size`` is not
    128.

    Otherwise ``envs.VLLM_USE_TRTLLM_DECODE_ATTENTION`` decides when it is
    set (every value except ``"0"`` enables the path). When it is unset
    the path is selected for ``batch_size <= 256``,
    ``max_seq_len < 131072`` and ``kv_cache_dtype == "auto"``.
    """
    backend = FlashInferBackend
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    # Head-shape guards, in the same order as one short-circuiting `or`.
    if num_qo_heads // num_kv_heads > 8:
        return False
    if num_qo_heads % num_kv_heads != 0:
        return False
    if attn_head_size != 128:
        return False

    override = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if override is None:
        # Environment variable not set - fall back to auto-detection.
        # Only supports attention head size of 128 (checked above).
        auto_selected = (backend.cached_sm100a_supported
                         and batch_size <= 256 and max_seq_len < 131072
                         and kv_cache_dtype == "auto")
        if auto_selected:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_selected

    # Environment variable is set - respect it. Only "0" disables the
    # path, because it is otherwise enabled automatically when the batch
    # size condition is satisfied.
    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     override)
    enabled = not (override == "0")
    if enabled:
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
                         "using TRTLLM decode attention.")
    return enabled
@
staticmethod
def
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
...
...
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
num_decodes
>
0
:
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
()
if
not
FlashInferBackend
.
use_trtllm_decode_attention
(
if
not
use_trtllm_decode_attention
(
num_decodes
,
attn_metadata
.
max_seq_len
,
self
.
cache_config
.
cache_dtype
,
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
...
...
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
decode_query
=
query
[:
num_decode_tokens
]
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_wrapper
is
not
None
if
not
FlashInferBackend
.
use_trtllm_decode_attention
(
if
not
use_trtllm_decode_attention
(
attn_metadata
.
num_decodes
,
attn_metadata
.
max_seq_len
,
self
.
kv_cache_dtype
,
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
):
...
...
vllm/v1/attention/backends/mla/common.py
View file @
f81c1bb0
...
...
@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.utils.flashinfer
import
has_nvidia_artifactory
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_per_layer_parameters
,
infer_global_hyperparameters
,
...
...
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool:
    """Report whether the flashinfer prefill path should be used.

    True only when ``flashinfer_available`` is set,
    ``envs.VLLM_USE_CUDNN_PREFILL`` is not set, and
    ``current_platform.is_device_capability(100)`` holds.
    """
    # For blackwell default to flashinfer prefill if it's available since
    # it is faster than FA2.
    return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100))
def use_cudnn_prefill() -> bool:
    """Report whether the cuDNN prefill path should be used.

    True only when ``flashinfer_available`` is set,
    ``envs.VLLM_USE_CUDNN_PREFILL`` is set,
    ``current_platform.is_device_capability(100)`` holds, and
    ``has_nvidia_artifactory()`` reports the NVIDIA artifactory as
    accessible.
    """
    # The artifactory check is last so that `and` short-circuits before
    # calling it whenever a cheaper condition already rules the path out.
    return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100)
            and has_nvidia_artifactory())
# Currently 394MB, this can be tuned based on GEMM sizes used.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment