Unverified Commit f81c1bb0 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)

parent fb0e0d46
...@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention ...@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
from vllm.utils.flashinfer import use_trtllm_decode_attention
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -56,7 +56,6 @@ if TYPE_CHECKING: ...@@ -56,7 +56,6 @@ if TYPE_CHECKING:
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
cached_sm100a_supported: Optional[bool] = None
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend): ...@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@staticmethod
def use_trtllm_decode_attention(
batch_size: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: Optional[int],
num_kv_heads: Optional[int],
attn_head_size: Optional[int],
) -> bool:
if FlashInferBackend.cached_sm100a_supported is None:
FlashInferBackend.cached_sm100a_supported = (
current_platform.has_device_capability(100))
if not FlashInferBackend.cached_sm100a_supported:
return False
# Check if the dimensions are supported by TRTLLM decode attention
if (attn_head_size is None or num_qo_heads is None
or num_kv_heads is None or num_qo_heads // num_kv_heads > 8
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
return False
env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
no_use_trtllm = (env_value == "0")
if not no_use_trtllm:
logger.info_once("Using TRTLLM decode attention.")
return not no_use_trtllm
else:
# Environment variable not set - use auto-detection
use_trtllm = (FlashInferBackend.cached_sm100a_supported
and batch_size <= 256 and max_seq_len < 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once(
"Using TRTLLM decode attention (auto-detected).")
return use_trtllm
@dataclass @dataclass
class PerLayerParameters: class PerLayerParameters:
...@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
assert decode_meta.decode_wrapper._sm_scale == softmax_scale assert decode_meta.decode_wrapper._sm_scale == softmax_scale
# TODO: @pavanimajety Remove this once the switch happens # TODO: @pavanimajety Remove this once the switch happens
# inside flashinfer. # inside flashinfer.
if not FlashInferBackend.use_trtllm_decode_attention( if not use_trtllm_decode_attention(
num_decode_tokens, attn_metadata.max_decode_seq_len, num_decode_tokens, attn_metadata.max_decode_seq_len,
kv_cache_dtype, attn_metadata.num_qo_heads, kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim): attn_metadata.num_kv_heads, attn_metadata.head_dim):
......
...@@ -10,12 +10,25 @@ import contextlib ...@@ -10,12 +10,25 @@ import contextlib
import functools import functools
import importlib import importlib
import importlib.util import importlib.util
from typing import Any, Callable, NoReturn import os
from typing import Any, Callable, NoReturn, Optional
import requests
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
"FLASHINFER_CUBINS_REPOSITORY",
"https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", # noqa: E501
)
@functools.cache @functools.cache
def has_flashinfer() -> bool: def has_flashinfer() -> bool:
...@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool: ...@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
return True return True
@functools.cache
def has_nvidia_artifactory() -> bool:
"""Return ``True`` if NVIDIA's artifactory is accessible.
This checks connectivity to the kernel inference library artifactory
which is required for downloading certain cubin kernels like TRTLLM FHMA.
"""
try:
# Use a short timeout to avoid blocking for too long
response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
accessible = response.status_code == 200
if accessible:
logger.debug_once("NVIDIA artifactory is accessible")
else:
logger.warning_once(
"NVIDIA artifactory returned failed status code: %d",
response.status_code)
return accessible
except Exception as e:
logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
return False
def use_trtllm_decode_attention(
num_tokens: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: Optional[int],
num_kv_heads: Optional[int],
attn_head_size: Optional[int],
) -> bool:
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if not (current_platform.is_device_capability(100)
and has_nvidia_artifactory()):
return False
# Check if the dimensions are supported by TRTLLM decode attention
if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
or num_qo_heads // num_kv_heads > 8
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
return False
env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
no_use_trtllm = (env_value == "0")
if not no_use_trtllm:
logger.info_once("Using TRTLLM decode attention.")
return not no_use_trtllm
else:
# Environment variable not set - use auto-detection
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once(
"Using TRTLLM decode attention (auto-detected).")
return use_trtllm
__all__ = [ __all__ = [
"has_flashinfer", "has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_trtllm_fp8_block_scale_moe",
...@@ -117,4 +194,6 @@ __all__ = [ ...@@ -117,4 +194,6 @@ __all__ = [
"autotune", "autotune",
"has_flashinfer_moe", "has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"use_trtllm_decode_attention",
] ]
...@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType) AttentionType)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
...@@ -38,7 +38,6 @@ logger = init_logger(__name__) ...@@ -38,7 +38,6 @@ logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
cached_sm100a_supported: Optional[bool] = None
@classmethod @classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_dtypes(cls) -> list[torch.dtype]:
...@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend): ...@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order return stride_order
@staticmethod
def use_trtllm_decode_attention(
batch_size: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: int,
num_kv_heads: int,
attn_head_size: int,
) -> bool:
if FlashInferBackend.cached_sm100a_supported is None:
FlashInferBackend.cached_sm100a_supported = (
current_platform.has_device_capability(100))
if not FlashInferBackend.cached_sm100a_supported:
return False
if (num_qo_heads // num_kv_heads > 8
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
return False
env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
no_use_trtllm = env_value == "0"
if not no_use_trtllm:
logger.info_once(
"VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
"using TRTLLM decode attention.")
return not no_use_trtllm
else:
# Environment variable not set - use auto-detection
# Only supports attention head size of 128
use_trtllm = (FlashInferBackend.cached_sm100a_supported
and batch_size <= 256 and max_seq_len < 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once(
"Using TRTLLM decode attention (auto-detected).")
return use_trtllm
@staticmethod @staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"): if kv_cache_dtype in ("fp8", "fp8_e4m3"):
...@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if num_decodes > 0: if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper() attn_metadata.decode_wrapper = self._get_decode_wrapper()
if not FlashInferBackend.use_trtllm_decode_attention( if not use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len, num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
...@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
decode_query = query[:num_decode_tokens] decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None assert decode_wrapper is not None
if not FlashInferBackend.use_trtllm_decode_attention( if not use_trtllm_decode_attention(
attn_metadata.num_decodes, attn_metadata.max_seq_len, attn_metadata.num_decodes, attn_metadata.max_seq_len,
self.kv_cache_dtype, attn_metadata.num_qo_heads, self.kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim): attn_metadata.num_kv_heads, attn_metadata.head_dim):
......
...@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, AttentionMetadataBuilder, CommonAttentionMetadata,
get_per_layer_parameters, infer_global_hyperparameters, get_per_layer_parameters, infer_global_hyperparameters,
...@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata) ...@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool: def use_flashinfer_prefill() -> bool:
if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
# For blackwell default to flashinfer prefill if its available since # For blackwell default to flashinfer prefill if its available since
# its faster than FA2. # it is faster than FA2.
return current_platform.has_device_capability(100) return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
return False and current_platform.is_device_capability(100))
def use_cudnn_prefill() -> bool: def use_cudnn_prefill() -> bool:
if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL: return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
return current_platform.has_device_capability(100) and current_platform.is_device_capability(100)
return False and has_nvidia_artifactory())
# Currently 394MB, this can be tuned based on GEMM sizes used. # Currently 394MB, this can be tuned based on GEMM sizes used.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment