Unverified commit 1cbccb6d, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Use `has_flashinfer` helper (#33177)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent bd92089d
...@@ -229,7 +229,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -229,7 +229,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights, get_and_maybe_dequant_weights,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
...@@ -599,13 +599,6 @@ except ImportError: ...@@ -599,13 +599,6 @@ except ImportError:
is_vllm_fa = False is_vllm_fa = False
@functools.cache
def flashinfer_available() -> bool:
import importlib.util
return importlib.util.find_spec("flashinfer") is not None
def dynamic_per_batched_tensor_quant( def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
): ):
...@@ -824,7 +817,7 @@ def use_flashinfer_prefill() -> bool: ...@@ -824,7 +817,7 @@ def use_flashinfer_prefill() -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if not ( if not (
not vllm_config.attention_config.disable_flashinfer_prefill not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available() and has_flashinfer()
and not vllm_config.attention_config.use_cudnn_prefill and not vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
): ):
...@@ -838,7 +831,7 @@ def use_cudnn_prefill() -> bool: ...@@ -838,7 +831,7 @@ def use_cudnn_prefill() -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
return ( return (
flashinfer_available() has_flashinfer()
and vllm_config.attention_config.use_cudnn_prefill and vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
and has_nvidia_artifactory() and has_nvidia_artifactory()
...@@ -851,7 +844,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool: ...@@ -851,7 +844,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if not ( if not (
flashinfer_available has_flashinfer()
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
): ):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment