Unverified commit 87360308, authored by Michael Goin, committed by GitHub
Browse files

[V1] Use FlashInfer by default on Blackwell GPUs (#19118)

parent aa49f148
...@@ -229,6 +229,21 @@ class CudaPlatformBase(Platform): ...@@ -229,6 +229,21 @@ class CudaPlatformBase(Platform):
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend") "triton_attn.TritonAttentionBackend")
if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try:
import flashinfer # noqa: F401
logger.info_once(
"Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends."
"flashinfer.FlashInferBackend")
except ImportError:
logger.info_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass
if cls.has_device_capability(80): if cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
......
...@@ -228,6 +228,30 @@ class Platform: ...@@ -228,6 +228,30 @@ class Platform:
return current_capability.to_int() >= capability return current_capability.to_int() >= capability
@classmethod
def is_device_capability(
    cls,
    capability: Union[tuple[int, int], int],
    device_id: int = 0,
) -> bool:
    """
    Check for an exact match between the device capability and `capability`.

    `capability` may be given in either of two forms:

    - A `(major, minor)` tuple, compared directly against the capability
      object returned by `get_device_capability`.
    - An integer `<major><minor>`, compared against that object's
      `to_int()` value. (See
      [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

    Returns `False` when `get_device_capability` reports `None` for
    `device_id`.
    """
    detected = cls.get_device_capability(device_id=device_id)
    if detected is None:
        # No capability reported for this device: nothing can match.
        return False

    # Put the detected capability in the same representation the caller
    # used, so that a single equality test covers both input forms.
    observed = detected if isinstance(capability, tuple) else detected.to_int()
    return observed == capability
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device.""" """Get the name of a device."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment