Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fec2b341
Unverified
Commit
fec2b341
authored
Oct 17, 2025
by
Jee Jee Li
Committed by
GitHub
Oct 17, 2025
Browse files
[Kernel] Lazy import FlashInfer (#26977)
parent
87bc0c49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
38 deletions
+25
-38
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+9
-8
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+16
-30
No files found.
tests/v1/sample/test_topk_topp_sampler.py
View file @
fec2b341
...
...
@@ -5,20 +5,13 @@ import torch
from
torch
import
Generator
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
is_flashinfer_available
,
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
DEVICE
=
current_platform
.
device_type
BATCH_SIZE
=
1024
VOCAB_SIZE
=
128
*
1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
is_flashinfer_available
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
...
...
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
"""
try
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
not
FLASHINFER_ENABLED
:
pytest
.
skip
(
"FlashInfer not installed or not available on this platform."
)
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
fec2b341
...
...
@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
logger
=
init_logger
(
__name__
)
try
:
import
flashinfer.sampling
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
class
TopKTopPSampler
(
nn
.
Module
):
"""
...
...
@@ -38,32 +31,18 @@ class TopKTopPSampler(nn.Module):
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
)
and
current_platform
.
is_cuda
()
):
if
is_flashinfer_available
:
flashinfer_version
=
flashinfer
.
__version__
if
version
.
parse
(
flashinfer_version
)
<
version
.
parse
(
"0.2.3"
):
logger
.
warning_once
(
"FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation."
)
self
.
forward
=
self
.
forward_native
elif
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger
.
info_once
(
"Using FlashInfer for top-p & top-k sampling."
)
self
.
forward
=
self
.
forward_cuda
else
:
logger
.
debug_once
(
"FlashInfer top-p/top-k sampling is available but disabled "
"by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
"after verifying accuracy for your workloads."
)
self
.
forward
=
self
.
forward_native
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger
.
info_once
(
"Using FlashInfer for top-p & top-k sampling."
)
self
.
forward
=
self
.
forward_cuda
else
:
logger
.
warning_once
(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer."
logger
.
debug_once
(
"FlashInfer top-p/top-k sampling is available but disabled "
"by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
"after verifying accuracy for your workloads."
)
self
.
forward
=
self
.
forward_native
elif
current_platform
.
is_cpu
():
arch
=
current_platform
.
get_cpu_architecture
()
# Fall back to native implementation for POWERPC and RISCV.
...
...
@@ -278,6 +257,13 @@ def flashinfer_sample(
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
import
flashinfer
if
version
.
parse
(
flashinfer
.
__version__
)
<
version
.
parse
(
"0.2.3"
):
raise
ImportError
(
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
)
assert
not
(
k
is
None
and
p
is
None
)
if
k
is
None
:
# Top-p only.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment