"tests/vscode:/vscode.git/clone" did not exist on "dea268336fb51b5bc342dca29189f3d4440ca2a0"
Unverified Commit a543e678 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Fix SM100 gpt-oss regression due to faulty attn sink support (#28561)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 2dacd573
...@@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get( ...@@ -35,9 +35,20 @@ FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
) )
@functools.cache
def has_flashinfer_cubin() -> bool:
"""Return `True` if flashinfer-cubin package is available."""
if envs.VLLM_HAS_FLASHINFER_CUBIN:
return True
if importlib.util.find_spec("flashinfer_cubin") is not None:
return True
logger.debug_once("flashinfer-cubin package was not found")
return False
@functools.cache @functools.cache
def has_flashinfer() -> bool: def has_flashinfer() -> bool:
"""Return `True` if FlashInfer is available.""" """Return `True` if flashinfer-python package is available."""
# Use find_spec to check if the module exists without importing it # Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects # This avoids potential CUDA initialization side effects
if importlib.util.find_spec("flashinfer") is None: if importlib.util.find_spec("flashinfer") is None:
...@@ -45,7 +56,7 @@ def has_flashinfer() -> bool: ...@@ -45,7 +56,7 @@ def has_flashinfer() -> bool:
return False return False
# When not using flashinfer cubin, # When not using flashinfer cubin,
# Also check if nvcc is available since it's required to JIT compile flashinfer # Also check if nvcc is available since it's required to JIT compile flashinfer
if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None: if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
logger.debug_once( logger.debug_once(
"FlashInfer unavailable since nvcc was not found " "FlashInfer unavailable since nvcc was not found "
"and not using pre-downloaded cubins" "and not using pre-downloaded cubins"
...@@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool: ...@@ -183,9 +194,8 @@ def has_nvidia_artifactory() -> bool:
This checks connectivity to the kernel inference library artifactory This checks connectivity to the kernel inference library artifactory
which is required for downloading certain cubin kernels like TRTLLM FHMA. which is required for downloading certain cubin kernels like TRTLLM FHMA.
""" """
# Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when # If we have pre-downloaded cubins, we can assume the cubins are available.
# it's true, we could assume the cubins are available. if has_flashinfer_cubin():
if envs.VLLM_HAS_FLASHINFER_CUBIN:
return True return True
try: try:
...@@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool: ...@@ -208,9 +218,13 @@ def has_nvidia_artifactory() -> bool:
@functools.cache @functools.cache
def supports_trtllm_attention() -> bool: def supports_trtllm_attention() -> bool:
""" """
TRTLLM attention is supported if the platform is SM100 and TRTLLM attention is supported if the platform is SM100,
NVIDIA artifactory is accessible NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
""" """
# Batch-invariant mode disables TRTLLM attention
if vllm_is_batch_invariant():
return False
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
return current_platform.is_device_capability(100) and has_nvidia_artifactory() return current_platform.is_device_capability(100) and has_nvidia_artifactory()
...@@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None: ...@@ -229,9 +243,6 @@ def force_use_trtllm_attention() -> bool | None:
return `True` if TRTLLM attention is forced to be used, return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used. return `False` if TRTLLM attention is forced to be not used.
""" """
if vllm_is_batch_invariant():
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
return False
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
......
...@@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend): ...@@ -229,6 +229,21 @@ class FlashInferBackend(AttentionBackend):
12, 1 12, 1
) )
@classmethod
def supports_sink(cls) -> bool:
"""FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
from vllm.utils.flashinfer import (
force_use_trtllm_attention,
supports_trtllm_attention,
)
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
if force_use_trtllm_attention() is False:
return False
# Check if TRTLLM is supported on this platform
return supports_trtllm_attention()
@classmethod @classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment