Unverified commit 9eec282c, authored by Harry Mellor, committed by GitHub
Browse files

Guard FlashInfer sampler using the same check as FlashInfer attention backend (#29415)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 0808eb81
......@@ -33,6 +33,16 @@ class TopKTopPSampler(nn.Module):
and current_platform.is_cuda()
):
if envs.VLLM_USE_FLASHINFER_SAMPLER:
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
capability = current_platform.get_device_capability()
assert capability is not None
if not FlashInferBackend.supports_compute_capability(capability):
capability_str = capability.as_version_str()
raise RuntimeError(
"FlashInfer does not support compute capability "
f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
)
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment