Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fec2b341
Unverified
Commit
fec2b341
authored
Oct 17, 2025
by
Jee Jee Li
Committed by
GitHub
Oct 17, 2025
Browse files
[Kernel] Lazy import FlashInfer (#26977)
parent
87bc0c49
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
38 deletions
+25
-38
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+9
-8
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+16
-30
No files found.
tests/v1/sample/test_topk_topp_sampler.py
View file @
fec2b341
...
@@ -5,20 +5,13 @@ import torch
...
@@ -5,20 +5,13 @@ import torch
from
torch
import
Generator
from
torch
import
Generator
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
apply_top_k_top_p
,
is_flashinfer_available
,
)
DEVICE
=
current_platform
.
device_type
DEVICE
=
current_platform
.
device_type
BATCH_SIZE
=
1024
BATCH_SIZE
=
1024
VOCAB_SIZE
=
128
*
1024
VOCAB_SIZE
=
128
*
1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
is_flashinfer_available
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
def
reset_default_device
():
...
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
...
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
sampling results due to randomness), so we will compare the probability
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
renormed consequently by top-k and then top-p of FlashInfer implementation.
"""
"""
try
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
not
FLASHINFER_ENABLED
:
if
not
FLASHINFER_ENABLED
:
pytest
.
skip
(
"FlashInfer not installed or not available on this platform."
)
pytest
.
skip
(
"FlashInfer not installed or not available on this platform."
)
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
fec2b341
...
@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
...
@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
import
flashinfer.sampling
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
class
TopKTopPSampler
(
nn
.
Module
):
class
TopKTopPSampler
(
nn
.
Module
):
"""
"""
...
@@ -38,15 +31,7 @@ class TopKTopPSampler(nn.Module):
...
@@ -38,15 +31,7 @@ class TopKTopPSampler(nn.Module):
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
)
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
)
and
current_platform
.
is_cuda
()
and
current_platform
.
is_cuda
()
):
):
if
is_flashinfer_available
:
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
flashinfer_version
=
flashinfer
.
__version__
if
version
.
parse
(
flashinfer_version
)
<
version
.
parse
(
"0.2.3"
):
logger
.
warning_once
(
"FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation."
)
self
.
forward
=
self
.
forward_native
elif
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger
.
info_once
(
"Using FlashInfer for top-p & top-k sampling."
)
logger
.
info_once
(
"Using FlashInfer for top-p & top-k sampling."
)
self
.
forward
=
self
.
forward_cuda
self
.
forward
=
self
.
forward_cuda
...
@@ -57,13 +42,7 @@ class TopKTopPSampler(nn.Module):
...
@@ -57,13 +42,7 @@ class TopKTopPSampler(nn.Module):
"after verifying accuracy for your workloads."
"after verifying accuracy for your workloads."
)
)
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
else
:
logger
.
warning_once
(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer."
)
self
.
forward
=
self
.
forward_native
elif
current_platform
.
is_cpu
():
elif
current_platform
.
is_cpu
():
arch
=
current_platform
.
get_cpu_architecture
()
arch
=
current_platform
.
get_cpu_architecture
()
# Fall back to native implementation for POWERPC and RISCV.
# Fall back to native implementation for POWERPC and RISCV.
...
@@ -278,6 +257,13 @@ def flashinfer_sample(
...
@@ -278,6 +257,13 @@ def flashinfer_sample(
does not. Call this function at the end of the forward pass to minimize
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
the synchronization overhead.
"""
"""
import
flashinfer
if
version
.
parse
(
flashinfer
.
__version__
)
<
version
.
parse
(
"0.2.3"
):
raise
ImportError
(
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
)
assert
not
(
k
is
None
and
p
is
None
)
assert
not
(
k
is
None
and
p
is
None
)
if
k
is
None
:
if
k
is
None
:
# Top-p only.
# Top-p only.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment