Unverified Commit fec2b341 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Kernel] Lazy import FlashInfer (#26977)

parent 87bc0c49
......@@ -5,20 +5,13 @@ import torch
from torch import Generator
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p,
is_flashinfer_available,
)
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
DEVICE = current_platform.device_type
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
if is_flashinfer_available:
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
@pytest.fixture(autouse=True)
def reset_default_device():
......@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
"""
try:
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
if not FLASHINFER_ENABLED:
pytest.skip("FlashInfer not installed or not available on this platform.")
......
......@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
logger = init_logger(__name__)
try:
import flashinfer.sampling
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False
class TopKTopPSampler(nn.Module):
"""
......@@ -38,15 +31,7 @@ class TopKTopPSampler(nn.Module):
logprobs_mode not in ("processed_logits", "processed_logprobs")
and current_platform.is_cuda()
):
if is_flashinfer_available:
flashinfer_version = flashinfer.__version__
if version.parse(flashinfer_version) < version.parse("0.2.3"):
logger.warning_once(
"FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation."
)
self.forward = self.forward_native
elif envs.VLLM_USE_FLASHINFER_SAMPLER:
if envs.VLLM_USE_FLASHINFER_SAMPLER:
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger.info_once("Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_cuda
......@@ -57,13 +42,7 @@ class TopKTopPSampler(nn.Module):
"after verifying accuracy for your workloads."
)
self.forward = self.forward_native
else:
logger.warning_once(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer."
)
self.forward = self.forward_native
elif current_platform.is_cpu():
arch = current_platform.get_cpu_architecture()
# Fall back to native implementation for POWERPC and RISCV.
......@@ -278,6 +257,13 @@ def flashinfer_sample(
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
import flashinfer
if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
raise ImportError(
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
)
assert not (k is None and p is None)
if k is None:
# Top-p only.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment