Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6be75db
Unverified
Commit
a6be75db
authored
Mar 08, 2026
by
PatchyTIS
Committed by
GitHub
Mar 07, 2026
Browse files
[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)
parent
ee54f9cd
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
940 additions
and
12 deletions
+940
-12
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+40
-3
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+28
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+7
-0
vllm/config/speculative.py
vllm/config/speculative.py
+9
-1
vllm/config/vllm.py
vllm/config/vllm.py
+5
-2
vllm/tool_parsers/hermes_tool_parser.py
vllm/tool_parsers/hermes_tool_parser.py
+2
-0
vllm/v1/spec_decode/ngram_proposer_gpu.py
vllm/v1/spec_decode/ngram_proposer_gpu.py
+660
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+182
-5
No files found.
tests/v1/e2e/test_async_scheduling.py
View file @
a6be75db
...
@@ -98,7 +98,7 @@ def test_without_spec_decoding(
...
@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@
single_gpu_only
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
16
)
@
large_gpu_mark
(
min_gb
=
16
)
def
test_with_spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_with_
eagle3_
spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test consistency and acceptance rates with some different combos of
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
spec decoding model length.
...
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
...
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
)
)
def
test_with_ngram_gpu_spec_decoding
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test ngram_gpu speculative decoding with different configurations.
This test specifically validates ngram_gpu behavior with various:
- Number of speculative tokens (2-6)
- Prompt lookup window sizes (min/max)
- Async scheduling enabled (as in production)
- Different executors and chunking settings
"""
# Variant with larger speculation window
ngram_gpu_config
=
{
"method"
:
"ngram_gpu"
,
"num_speculative_tokens"
:
3
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
}
# Test configurations covering various scenarios
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs
=
[
(
False
,
"mp"
,
False
,
None
,
False
),
(
False
,
"mp"
,
False
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
False
,
ngram_gpu_config
,
True
),
(
False
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"uni"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
True
),
]
# Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight
# and ngram_gpu doesn't require a specific draft model
run_tests
(
monkeypatch
,
MODEL
,
test_configs
,
[{}])
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
def
run_tests
(
def
run_tests
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
...
@@ -282,11 +318,12 @@ def run_test(
...
@@ -282,11 +318,12 @@ def run_test(
else
dict
(
gpu_memory_utilization
=
0.9
)
else
dict
(
gpu_memory_utilization
=
0.9
)
)
)
spec_mml
=
(
spec_config
or
{}).
get
(
"max_model_len"
)
spec_mml
=
(
spec_config
or
{}).
get
(
"max_model_len"
)
spec_method
=
(
spec_config
or
{}).
get
(
"method"
,
"none"
)
test_config
=
(
test_config
=
(
f
"executor=
{
executor
}
, preemption=
{
test_preemption
}
, "
f
"executor=
{
executor
}
, preemption=
{
test_preemption
}
, "
f
"async_sched=
{
async_scheduling
}
, "
f
"async_sched=
{
async_scheduling
}
, "
f
"chunk_prefill=
{
test_prefill_chunking
}
, "
f
"chunk_prefill=
{
test_prefill_chunking
}
, "
f
"spec_decoding=
{
spec_decoding
}
, spec_mml=
{
spec_mml
}
"
f
"spec_decoding=
{
spec_decoding
}
,
spec_method=
{
spec_method
}
,
spec_mml=
{
spec_mml
}
"
)
)
print
(
"-"
*
80
)
print
(
"-"
*
80
)
print
(
f
"---- TESTING
{
test_str
}
:
{
test_config
}
"
)
print
(
f
"---- TESTING
{
test_str
}
:
{
test_config
}
"
)
...
@@ -294,7 +331,7 @@ def run_test(
...
@@ -294,7 +331,7 @@ def run_test(
with
VllmRunner
(
with
VllmRunner
(
model
,
model
,
max_model_len
=
512
,
max_model_len
=
4096
,
enable_chunked_prefill
=
test_prefill_chunking
,
enable_chunked_prefill
=
test_prefill_chunking
,
# Force prefill chunking
# Force prefill chunking
max_num_batched_tokens
=
48
if
test_prefill_chunking
else
None
,
max_num_batched_tokens
=
48
if
test_prefill_chunking
else
None
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
a6be75db
...
@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
...
@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"async_scheduling"
,
[
True
],
ids
=
[
"async"
])
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
def
test_ngram_gpu_default_with_async_scheduling
(
async_scheduling
:
bool
,
):
"""
Test ngram_gpu speculative decoding (k=3) correctness with and without
async scheduling, validated via GSM8K accuracy.
Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
"""
qwen3_model
=
"Qwen/Qwen3-8B"
spec_llm
=
LLM
(
model
=
qwen3_model
,
speculative_config
=
{
"method"
:
"ngram_gpu"
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
"num_speculative_tokens"
:
2
,
},
max_model_len
=
4096
,
async_scheduling
=
async_scheduling
,
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
0.8
)
del
spec_llm
cleanup_dist_env_and_memory
()
@
single_gpu_only
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
@
large_gpu_mark
(
min_gb
=
20
)
def
test_suffix_decoding_acceptance
(
def
test_suffix_decoding_acceptance
(
...
...
vllm/compilation/backends.py
View file @
a6be75db
...
@@ -907,6 +907,13 @@ class VllmBackend:
...
@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache
=
not
is_compile_cache_enabled
(
self
.
inductor_config
)
disable_cache
=
not
is_compile_cache_enabled
(
self
.
inductor_config
)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled
=
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
use_ngram_gpu
()
)
disable_cache
=
disable_cache
or
is_ngram_gpu_enabled
if
disable_cache
:
if
disable_cache
:
logger
.
info_once
(
"vLLM's torch.compile cache is disabled."
,
scope
=
"local"
)
logger
.
info_once
(
"vLLM's torch.compile cache is disabled."
,
scope
=
"local"
)
else
:
else
:
...
...
vllm/config/speculative.py
View file @
a6be75db
...
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
...
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp"
,
"step3p5_mtp"
,
]
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
NgramGPUTypes
=
Literal
[
"ngram_gpu"
]
SpeculativeMethod
=
Literal
[
SpeculativeMethod
=
Literal
[
"ngram"
,
"ngram"
,
"medusa"
,
"medusa"
,
...
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
...
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model"
,
"draft_model"
,
"suffix"
,
"suffix"
,
EagleModelTypes
,
EagleModelTypes
,
NgramGPUTypes
,
]
]
...
@@ -364,6 +366,8 @@ class SpeculativeConfig:
...
@@ -364,6 +366,8 @@ class SpeculativeConfig:
self
.
quantization
=
self
.
target_model_config
.
quantization
self
.
quantization
=
self
.
target_model_config
.
quantization
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
self
.
model
=
"ngram"
self
.
model
=
"ngram"
elif
self
.
method
==
"ngram_gpu"
:
self
.
model
=
"ngram_gpu"
elif
self
.
method
==
"suffix"
:
elif
self
.
method
==
"suffix"
:
self
.
model
=
"suffix"
self
.
model
=
"suffix"
elif
self
.
method
==
"extract_hidden_states"
:
elif
self
.
method
==
"extract_hidden_states"
:
...
@@ -374,8 +378,9 @@ class SpeculativeConfig:
...
@@ -374,8 +378,9 @@ class SpeculativeConfig:
)
)
if
self
.
method
in
(
"ngram"
,
"[ngram]"
):
if
self
.
method
in
(
"ngram"
,
"[ngram]"
):
# Unified to "ngram" internally
self
.
method
=
"ngram"
self
.
method
=
"ngram"
if
self
.
method
in
(
"ngram"
,
"ngram_gpu"
):
# Set default values if not provided
# Set default values if not provided
if
self
.
prompt_lookup_min
is
None
and
self
.
prompt_lookup_max
is
None
:
if
self
.
prompt_lookup_min
is
None
and
self
.
prompt_lookup_max
is
None
:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
...
@@ -832,6 +837,9 @@ class SpeculativeConfig:
...
@@ -832,6 +837,9 @@ class SpeculativeConfig:
def
uses_extract_hidden_states
(
self
)
->
bool
:
def
uses_extract_hidden_states
(
self
)
->
bool
:
return
self
.
method
==
"extract_hidden_states"
return
self
.
method
==
"extract_hidden_states"
def
use_ngram_gpu
(
self
)
->
bool
:
return
self
.
method
==
"ngram_gpu"
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
method
=
self
.
method
model
=
(
model
=
(
...
...
vllm/config/vllm.py
View file @
a6be75db
...
@@ -41,7 +41,7 @@ from .offload import OffloadConfig
...
@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from
.parallel
import
ParallelConfig
from
.parallel
import
ParallelConfig
from
.profiler
import
ProfilerConfig
from
.profiler
import
ProfilerConfig
from
.scheduler
import
SchedulerConfig
from
.scheduler
import
SchedulerConfig
from
.speculative
import
EagleModelTypes
,
SpeculativeConfig
from
.speculative
import
EagleModelTypes
,
NgramGPUTypes
,
SpeculativeConfig
from
.structured_outputs
import
StructuredOutputsConfig
from
.structured_outputs
import
StructuredOutputsConfig
from
.utils
import
SupportsHash
,
config
,
replace
from
.utils
import
SupportsHash
,
config
,
replace
from
.weight_transfer
import
WeightTransferConfig
from
.weight_transfer
import
WeightTransferConfig
...
@@ -696,11 +696,13 @@ class VllmConfig:
...
@@ -696,11 +696,13 @@ class VllmConfig:
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
if
(
if
(
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
and
self
.
speculative_config
.
method
!=
"draft_model"
and
self
.
speculative_config
.
method
!=
"draft_model"
):
):
raise
ValueError
(
raise
ValueError
(
"Currently, async scheduling is only supported "
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding."
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
)
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
raise
ValueError
(
...
@@ -718,6 +720,7 @@ class VllmConfig:
...
@@ -718,6 +720,7 @@ class VllmConfig:
if
(
if
(
self
.
speculative_config
is
not
None
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
):
):
logger
.
warning_once
(
logger
.
warning_once
(
"Async scheduling not supported with %s-based "
"Async scheduling not supported with %s-based "
...
...
vllm/tool_parsers/hermes_tool_parser.py
View file @
a6be75db
...
@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
"arguments"
)
)
assert
current_tool_call
is
not
None
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
...
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
# the "prev" list for use in diffing for the next iteration
assert
isinstance
(
current_tool_call
,
dict
)
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
else
:
else
:
...
...
vllm/v1/spec_decode/ngram_proposer_gpu.py
0 → 100644
View file @
a6be75db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import
torch
from
torch
import
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
@
support_torch_compile
()
class
NgramGPUKernel
(
nn
.
Module
):
"""GPU-accelerated N-gram proposer using fully async tensor operations."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
device
:
torch
.
device
=
"cuda"
):
super
().
__init__
()
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
def
_find_first_and_extract_all_n_parallel
(
self
,
token_ids
:
torch
.
Tensor
,
seq_lengths
:
torch
.
Tensor
,
min_ngram_len
:
int
,
max_ngram_len
:
int
,
num_draft_tokens
:
int
,
)
->
torch
.
Tensor
:
"""
Find suffix n-gram matches and extract following tokens.
Searches for the earliest prior occurrence of the trailing n-gram,
tries multiple lengths, and picks the longest valid match.
Args:
token_ids: Token IDs for each sequence
seq_lengths: Actual length of each sequence (excluding padding)
min_ngram_len: Minimum n-gram size to search for (e.g., 2)
max_ngram_len: Maximum n-gram size to search for (e.g., 5)
num_draft_tokens: Number of tokens to extract after match (k)
Returns:
Draft token predictions; -1 means invalid/no match.
"""
batch_size
=
token_ids
.
shape
[
0
]
max_seq_len
=
token_ids
.
shape
[
1
]
device
=
token_ids
.
device
num_ngram_sizes
=
max_ngram_len
-
min_ngram_len
+
1
# All n-gram sizes to try.
ngram_lengths
=
torch
.
arange
(
min_ngram_len
,
max_ngram_len
+
1
,
device
=
device
)
batch_indices
=
torch
.
arange
(
batch_size
,
device
=
device
)
# Earliest match per (sequence, ngram_len); -1 means no match.
first_match_positions
=
torch
.
full
(
(
batch_size
,
num_ngram_sizes
),
-
1
,
dtype
=
torch
.
long
,
device
=
device
)
for
i
,
ngram_len
in
enumerate
(
range
(
min_ngram_len
,
max_ngram_len
+
1
)):
# Sliding windows of size ngram_len; unfold is O(1) view.
search_windows
=
token_ids
.
unfold
(
1
,
ngram_len
,
1
)
num_windows
=
search_windows
.
shape
[
1
]
# Trailing suffix (last ngram_len tokens) for each sequence.
suffix_starts
=
seq_lengths
-
ngram_len
suffix_indices
=
suffix_starts
.
unsqueeze
(
1
)
+
torch
.
arange
(
ngram_len
,
device
=
device
)
suffix
=
torch
.
gather
(
token_ids
,
1
,
suffix_indices
.
clamp
(
min
=
0
))
# Window matches for each sequence.
matches
=
(
search_windows
==
suffix
.
unsqueeze
(
1
)).
all
(
dim
=-
1
)
# Match must leave room for at least one draft token.
max_valid_suffix_start
=
seq_lengths
-
ngram_len
-
1
window_positions
=
torch
.
arange
(
num_windows
,
device
=
device
)
valid_mask
=
window_positions
<=
max_valid_suffix_start
.
unsqueeze
(
1
)
final_matches
=
matches
&
valid_mask
# Find earliest match (argmax=0 when empty; verify with has_match).
first_match_idx
=
torch
.
argmax
(
final_matches
.
int
(),
dim
=
1
)
has_match
=
final_matches
[
batch_indices
,
first_match_idx
]
# Store valid match positions (window index = position).
first_match_positions
[:,
i
]
=
torch
.
where
(
has_match
,
first_match_idx
,
-
1
)
# Select the longest n-gram with a match.
best_ngram_idx
=
(
first_match_positions
>=
0
).
int
().
flip
(
dims
=
[
1
]).
argmax
(
dim
=
1
)
best_ngram_idx
=
num_ngram_sizes
-
1
-
best_ngram_idx
# Flip back
# Match position for the best n-gram.
best_match_pos
=
first_match_positions
[
batch_indices
,
best_ngram_idx
]
# Avoid data-dependent branching.
has_any_match
=
best_match_pos
>=
0
# Length of the best matching n-gram.
best_ngram_lengths
=
ngram_lengths
[
best_ngram_idx
]
# Start position right after the matched suffix.
draft_start
=
torch
.
where
(
has_any_match
,
best_match_pos
+
best_ngram_lengths
,
torch
.
zeros_like
(
best_match_pos
),
)
tokens_available
=
seq_lengths
-
draft_start
# Gather indices for draft tokens.
draft_indices
=
draft_start
.
unsqueeze
(
1
)
+
torch
.
arange
(
num_draft_tokens
,
device
=
device
)
draft_indices
=
draft_indices
.
clamp
(
min
=
0
,
max
=
max_seq_len
-
1
)
# Extract draft tokens; gather always runs.
draft_tokens
=
torch
.
gather
(
token_ids
,
1
,
draft_indices
)
# Mask positions beyond available tokens.
position_indices
=
torch
.
arange
(
num_draft_tokens
,
device
=
device
).
unsqueeze
(
0
)
valid_positions
=
position_indices
<
tokens_available
.
unsqueeze
(
1
)
draft_tokens
=
torch
.
where
(
valid_positions
,
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
# If no match, mask all positions.
draft_tokens
=
torch
.
where
(
has_any_match
.
unsqueeze
(
1
),
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
return
draft_tokens
def
forward
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
token_ids_gpu
:
torch
.
Tensor
,
combined_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Forward pass for N-gram proposal using GPU tensor operations.
Args:
num_tokens_no_spec: Number of tokens for each sequence [batch_size]
token_ids_gpu: Token IDs [batch_size, max_len]
combined_mask: Whether each sequence is valid for spec decode [batch_size]
Returns:
draft_tokens: [batch_size, k] on GPU
num_valid_draft_tokens: [batch_size] int32 on GPU, count of
leading valid (non -1) tokens per request.
"""
device
=
token_ids_gpu
.
device
# Infer batch size to preserve dynamic shape.
actual_batch_size
=
token_ids_gpu
.
shape
[
0
]
# Allocate in forward so torch.compile can optimize.
# NOTE(patchy): Do NOT pre-allocate this as a buffer
# it breaks torch.compile
draft_tokens
=
torch
.
full
(
(
actual_batch_size
,
self
.
k
),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
results
=
self
.
_find_first_and_extract_all_n_parallel
(
token_ids_gpu
,
num_tokens_no_spec
,
min_ngram_len
=
self
.
min_n
,
max_ngram_len
=
self
.
max_n
,
num_draft_tokens
=
self
.
k
,
)
draft_tokens
=
torch
.
where
(
combined_mask
.
unsqueeze
(
1
),
results
,
-
1
)
# Count leading contiguous valid (non -1) tokens per request.
is_valid
=
draft_tokens
!=
-
1
# [batch, k]
cum_valid
=
is_valid
.
int
().
cumsum
(
dim
=
1
)
# [batch, k]
positions
=
torch
.
arange
(
1
,
self
.
k
+
1
,
device
=
device
).
unsqueeze
(
0
)
num_valid_draft_tokens
=
(
cum_valid
==
positions
).
int
().
sum
(
dim
=
1
)
return
draft_tokens
,
num_valid_draft_tokens
def
load_model
(
self
,
*
args
,
**
kwargs
):
"""No model to load for N-gram proposer."""
pass
class
NgramProposerGPU
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
[
"none"
],
splitting_ops
=
[],
compile_sizes
=
[],
inductor_compile_config
=
{
"enable_auto_functionalized_v2"
:
False
,
"max_autotune"
:
True
,
"aggressive_fusion"
:
True
,
"triton.autotune_pointwise"
:
True
,
"coordinate_descent_tuning"
:
True
,
"use_mixed_mm"
:
False
,
},
cudagraph_mode
=
CUDAGraphMode
.
NONE
,
)
model_config
=
vllm_config
.
model_config
speculative_config
=
vllm_config
.
speculative_config
scheduler_config
=
vllm_config
.
scheduler_config
self
.
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
,
model_config
=
model_config
,
speculative_config
=
speculative_config
,
scheduler_config
=
scheduler_config
,
)
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
self
.
kernel
=
NgramGPUKernel
(
vllm_config
=
self
.
vllm_config
,
prefix
=
"ngram_gpu_kernel"
,
device
=
device
)
self
.
kernel
.
to
(
device
)
self
.
kernel
.
eval
()
self
.
_dummy_run
()
def
_dummy_run
(
self
):
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
=
self
.
_generate_dummy_data
(
batch_size
=
self
.
max_num_seqs
,
max_seq_len
=
self
.
max_model_len
,
pattern_len
=
self
.
k
,
device
=
self
.
device
,
)
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens
>=
self
.
min_n
)
for
_
in
range
(
3
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
_
,
_
=
self
.
kernel
(
num_tokens
,
token_ids
,
combined_mask
)
def
_generate_dummy_data
(
self
,
batch_size
:
int
,
max_seq_len
:
int
,
pattern_len
:
int
,
device
:
str
=
"cuda"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Generate random test data with n-gram repetitions.
Args:
batch_size: Number of sequences in the batch
max_seq_len: Maximum sequence length
pattern_len: Length of patterns to inject for matching
device: Device to place tensors on
Returns:
token_ids: [batch_size, max_seq_len] tensor
num_tokens: [batch_size] tensor
sampled_flags: [batch_size] bool tensor
valid_mask: [batch_size] bool tensor
"""
token_ids
=
torch
.
zeros
(
batch_size
,
max_seq_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
num_tokens
=
torch
.
randint
(
pattern_len
,
max_seq_len
,
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
device
)
sampled_flags
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
return
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
def
propose
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
# [batch_size]
token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, max_len]
valid_sampled_token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, num_spec_tokens + 1]
valid_sampled_tokens_count
:
torch
.
Tensor
,
# [batch_size]
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Propose draft tokens using GPU-accelerated n-gram matching.
Scatter sampled tokens into `token_ids_gpu`, compute temporary
updated lengths, then run the kernel.
Args:
num_tokens_no_spec: Number of tokens per sequence (read-only)
token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
valid_sampled_tokens_count: Count of valid tokens per sequence
Returns:
draft_tokens: Proposed draft token IDs [batch_size, k]
num_valid_draft_tokens: Count of leading valid draft tokens
per request [batch_size]
"""
assert
token_ids_gpu
.
device
==
self
.
device
assert
num_tokens_no_spec
.
device
==
self
.
device
batch_size
=
num_tokens_no_spec
.
shape
[
0
]
max_seq_len
=
token_ids_gpu
.
shape
[
1
]
max_new_tokens
=
valid_sampled_token_ids_gpu
.
shape
[
1
]
# num_spec_tokens + 1
# Scatter newly sampled tokens into token_ids_gpu.
offsets
=
torch
.
arange
(
max_new_tokens
,
device
=
self
.
device
)
write_positions
=
num_tokens_no_spec
.
unsqueeze
(
1
)
+
offsets
.
unsqueeze
(
0
)
valid_write_mask
=
offsets
.
unsqueeze
(
0
)
<
valid_sampled_tokens_count
.
unsqueeze
(
1
)
in_bounds
=
write_positions
<
max_seq_len
scatter_mask
=
(
valid_write_mask
&
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
in_bounds
)
write_positions_long
=
write_positions
.
clamp
(
max
=
max_seq_len
-
1
).
long
()
existing_values
=
token_ids_gpu
.
gather
(
1
,
write_positions_long
)
tokens_cast
=
valid_sampled_token_ids_gpu
.
to
(
token_ids_gpu
.
dtype
)
tokens_to_scatter
=
torch
.
where
(
scatter_mask
,
tokens_cast
,
existing_values
,
)
token_ids_gpu
.
scatter_
(
1
,
write_positions_long
,
tokens_to_scatter
)
num_tokens_tmp
=
num_tokens_no_spec
+
valid_sampled_tokens_count
# Compute validity masks.
sampled_flags
=
valid_sampled_tokens_count
>
0
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
with
set_forward_context
(
None
,
self
.
vllm_config
):
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens_tmp
>=
self
.
min_n
)
with
record_function_or_nullcontext
(
"ngram_proposer_gpu: kernel"
):
draft_tokens
,
num_valid_draft_tokens
=
self
.
kernel
(
num_tokens_tmp
,
token_ids_gpu
,
combined_mask
,
)
return
draft_tokens
,
num_valid_draft_tokens
def
update_token_ids_ngram
(
self
,
sampled_token_ids
:
torch
.
Tensor
|
list
[
list
[
int
]],
gpu_input_batch
:
InputBatch
,
token_ids_gpu
:
torch
.
Tensor
,
num_tokens_no_spec
:
torch
.
Tensor
,
discard_request_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Prepare speculative decoding inputs on device:
compute next token ids and valid counts, honoring discarded requests
and rejected tokens, without CPU-GPU sync.
"""
num_reqs
=
gpu_input_batch
.
num_reqs
if
isinstance
(
sampled_token_ids
,
list
):
# When disable_padded_drafter_batch=True, sampled_token_ids is
# an irregular list[list[int]] where sublists may have different
# lengths (including empty lists for discarded requests).
# Pad all sublists to the same length with -1 before converting
# to tensor.
max_len
=
max
(
(
len
(
sublist
)
for
sublist
in
sampled_token_ids
),
default
=
0
,
)
# Ensure at least length 1 for tensor creation
max_len
=
max
(
max_len
,
1
)
padded_list
=
[
sublist
+
[
-
1
]
*
(
max_len
-
len
(
sublist
))
for
sublist
in
sampled_token_ids
]
sampled_token_ids
=
torch
.
tensor
(
padded_list
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
(
"sampled_token_ids should be a torch.Tensor for ngram_gpu"
)
# Backup last valid token before speculative tokens.
backup_indices
=
(
num_tokens_no_spec
[:
num_reqs
]
-
1
).
clamp
(
min
=
0
).
long
()
backup_next_token_ids
=
torch
.
gather
(
token_ids_gpu
[:
num_reqs
],
dim
=
1
,
index
=
backup_indices
.
unsqueeze
(
1
)
).
squeeze
(
1
)
valid_sampled_token_ids_gpu
=
sampled_token_ids
.
clone
()
# Invalidate sampled tokens for discarded requests.
discard_mask_expanded
=
discard_request_mask
[:
num_reqs
].
unsqueeze
(
1
)
valid_sampled_token_ids_gpu
.
masked_fill_
(
discard_mask_expanded
,
-
1
)
# Mask valid tokens within each request.
valid_mask
=
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
(
valid_sampled_token_ids_gpu
<
gpu_input_batch
.
vocab_size
)
# Count valid tokens per request.
valid_sampled_tokens_count
=
valid_mask
.
sum
(
dim
=
1
)
# Rightmost valid index per row.
last_valid_indices
=
valid_sampled_tokens_count
-
1
last_valid_indices_safe
=
torch
.
clamp
(
last_valid_indices
,
min
=
0
)
# Last valid token from each row; undefined if none.
selected_tokens
=
torch
.
gather
(
valid_sampled_token_ids_gpu
,
1
,
last_valid_indices_safe
.
unsqueeze
(
1
)
).
squeeze
(
1
)
# Use last token if valid; otherwise fallback to backup.
next_token_ids
=
torch
.
where
(
last_valid_indices
!=
-
1
,
selected_tokens
,
backup_next_token_ids
,
)
return
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
def
load_model
(
self
,
*
args
,
**
kwargs
):
self
.
kernel
.
load_model
(
*
args
,
**
kwargs
)
def
update_scheduler_for_invalid_drafts
(
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
,
req_id_to_index
:
dict
[
str
,
int
],
)
->
None
:
"""Trim invalid speculative slots using per-request valid draft counts.
Args:
num_valid_draft_tokens_event: Event for async D2H completion.
num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
scheduler_output: Scheduler metadata to update in-place.
req_id_to_index: Request-id to batch-index mapping.
"""
req_data
=
scheduler_output
.
scheduled_cached_reqs
num_valid_draft_tokens_event
.
synchronize
()
for
req_id
in
req_data
.
req_ids
:
req_index
=
req_id_to_index
.
get
(
req_id
)
if
req_index
is
None
:
continue
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
)
if
spec_token_ids
is
None
:
continue
scheduled_k
=
len
(
spec_token_ids
)
valid_k
=
int
(
num_valid_draft_tokens_cpu
[
req_index
].
item
())
valid_k
=
max
(
0
,
min
(
valid_k
,
scheduled_k
))
tokens_to_trim
=
scheduled_k
-
valid_k
scheduler_output
.
total_num_scheduled_tokens
-=
tokens_to_trim
scheduler_output
.
num_scheduled_tokens
[
req_id
]
-=
tokens_to_trim
if
valid_k
==
0
:
scheduler_output
.
scheduled_spec_decode_tokens
.
pop
(
req_id
,
None
)
else
:
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
]
=
spec_token_ids
[
:
valid_k
]
def
update_ngram_gpu_tensors_incremental
(
input_batch
:
InputBatch
,
token_ids_gpu_tensor
:
torch
.
Tensor
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
new_reqs
:
list
[
CachedRequestState
],
device
:
torch
.
device
,
_pinned_idx_buf
:
torch
.
Tensor
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
for ngram GPU proposer.
"""
prev_req_id_to_index
=
input_batch
.
prev_req_id_to_index
curr_req_id_to_index
=
input_batch
.
req_id_to_index
if
not
curr_req_id_to_index
:
return
active_indices
=
list
(
curr_req_id_to_index
.
values
())
n_active
=
len
(
active_indices
)
# Use resident pinned buffers to avoid per-call allocation.
active_idx_cpu
=
_pinned_idx_buf
[:
n_active
]
active_idx_cpu
.
copy_
(
torch
.
as_tensor
(
active_indices
,
dtype
=
torch
.
long
))
active_idx_gpu
=
active_idx_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
new_req_ids
=
{
req
.
req_id
for
req
in
new_reqs
}
# First run, no previous state.
if
prev_req_id_to_index
is
None
:
for
idx
in
active_indices
:
num_tokens
=
input_batch
.
num_tokens_no_spec
[
idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
idx
,
:
num_tokens
],
non_blocking
=
True
,
)
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
return
# Detect index changes for reorder.
reorder_src
:
list
[
int
]
=
[]
reorder_dst
:
list
[
int
]
=
[]
for
req_id
,
curr_idx
in
curr_req_id_to_index
.
items
():
if
req_id
in
new_req_ids
:
continue
prev_idx
=
prev_req_id_to_index
.
get
(
req_id
)
if
prev_idx
is
not
None
and
prev_idx
!=
curr_idx
:
reorder_src
.
append
(
prev_idx
)
reorder_dst
.
append
(
curr_idx
)
if
reorder_src
:
src_tensor
=
torch
.
tensor
(
reorder_src
,
dtype
=
torch
.
long
,
device
=
device
)
dst_tensor
=
torch
.
tensor
(
reorder_dst
,
dtype
=
torch
.
long
,
device
=
device
)
temp_token_ids
=
token_ids_gpu_tensor
[
src_tensor
].
clone
()
temp_num_tokens
=
num_tokens_no_spec_gpu
[
src_tensor
].
clone
()
token_ids_gpu_tensor
[
dst_tensor
]
=
temp_token_ids
num_tokens_no_spec_gpu
[
dst_tensor
]
=
temp_num_tokens
# Full copy for new/resumed requests.
for
req_state
in
new_reqs
:
new_req_idx
=
curr_req_id_to_index
.
get
(
req_state
.
req_id
)
if
new_req_idx
is
None
:
continue
num_tokens
=
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
new_req_idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
new_req_idx
,
:
num_tokens
],
non_blocking
=
True
,
)
# Always batch-sync sequence lengths from CPU for ALL active requests.
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
def
_sync_num_tokens
(
input_batch
:
InputBatch
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
active_idx_cpu
:
torch
.
Tensor
,
active_idx_gpu
:
torch
.
Tensor
,
n_active
:
int
,
device
:
torch
.
device
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Batch-sync GPU sequence lengths from CPU source of truth.
Inputs:
input_batch: Batch container with CPU length tensor.
num_tokens_no_spec_gpu: Destination GPU length tensor.
active_idx_cpu: Active request indices on CPU.
active_idx_gpu: Active request indices on GPU.
n_active: Number of active requests.
device: Target CUDA device.
_pinned_val_buf: Resident pinned int32 staging buffer.
Outputs:
None (updates num_tokens_no_spec_gpu in-place).
"""
src_cpu
=
input_batch
.
num_tokens_no_spec_cpu_tensor
vals
=
_pinned_val_buf
[:
n_active
]
vals
.
copy_
(
src_cpu
.
index_select
(
0
,
active_idx_cpu
))
num_tokens_no_spec_gpu
.
index_copy_
(
0
,
active_idx_gpu
,
vals
.
to
(
device
=
device
,
non_blocking
=
True
),
)
def
copy_num_valid_draft_tokens
(
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
,
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens
:
torch
.
Tensor
|
None
,
batch_size
:
int
,
)
->
None
:
"""
Async D2H copy of per-request valid draft counts.
"""
if
num_valid_draft_tokens
is
None
:
return
num_reqs_to_copy
=
min
(
batch_size
,
num_valid_draft_tokens
.
shape
[
0
])
if
num_reqs_to_copy
<=
0
:
return
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
num_valid_draft_tokens_copy_stream
):
num_valid_draft_tokens_copy_stream
.
wait_stream
(
default_stream
)
num_valid_draft_tokens_cpu
[:
num_reqs_to_copy
].
copy_
(
num_valid_draft_tokens
[:
num_reqs_to_copy
],
non_blocking
=
True
)
num_valid_draft_tokens_event
.
record
()
vllm/v1/worker/gpu_input_batch.py
View file @
a6be75db
...
@@ -127,7 +127,13 @@ class InputBatch:
...
@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big.
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self
.
req_prompt_embeds
:
dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
req_prompt_embeds
:
dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_tokens_no_spec
=
self
.
num_tokens_no_spec_cpu_tensor
.
numpy
()
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
(
max_num_reqs
,),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a6be75db
...
@@ -10,7 +10,7 @@ from collections import defaultdict
...
@@ -10,7 +10,7 @@ from collections import defaultdict
from
collections.abc
import
Iterable
,
Iterator
,
Sequence
from
collections.abc
import
Iterable
,
Iterator
,
Sequence
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
copy
import
copy
,
deepcopy
from
copy
import
copy
,
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
...
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
...
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
NgramProposerGPU
,
copy_num_valid_draft_tokens
,
update_ngram_gpu_tensors_incremental
,
update_scheduler_for_invalid_drafts
,
)
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.structured_output.utils
import
apply_grammar_bitmask
from
vllm.v1.structured_output.utils
import
apply_grammar_bitmask
from
vllm.v1.utils
import
CpuGpuBuffer
,
record_function_or_nullcontext
from
vllm.v1.utils
import
CpuGpuBuffer
,
record_function_or_nullcontext
...
@@ -424,7 +430,7 @@ class GPUModelRunner(
...
@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# to make sure we are synced across pp ranks
# TODO: Support overlapping mi
r
co-batches
# TODO: Support overlapping mic
r
o-batches
# https://github.com/vllm-project/vllm/issues/18019
# https://github.com/vllm-project/vllm/issues/18019
self
.
broadcast_pp_output
=
(
self
.
broadcast_pp_output
=
(
self
.
parallel_config
.
distributed_executor_backend
==
"external_launcher"
self
.
parallel_config
.
distributed_executor_backend
==
"external_launcher"
...
@@ -493,6 +499,7 @@ class GPUModelRunner(
...
@@ -493,6 +499,7 @@ class GPUModelRunner(
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
self
.
drafter
:
(
self
.
drafter
:
(
NgramProposer
# noqa: F823
NgramProposer
# noqa: F823
|
NgramProposerGPU
|
SuffixDecodingProposer
|
SuffixDecodingProposer
|
EagleProposer
|
EagleProposer
|
DraftModelProposer
|
DraftModelProposer
...
@@ -509,6 +516,23 @@ class GPUModelRunner(
...
@@ -509,6 +516,23 @@ class GPUModelRunner(
device
=
self
.
device
,
device
=
self
.
device
,
runner
=
self
,
runner
=
self
,
)
)
elif
self
.
speculative_config
.
use_ngram_gpu
():
self
.
drafter
=
NgramProposerGPU
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
num_tokens_no_spec_gpu
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
token_ids_gpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
_ngram_pinned_idx_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
long
,
pin_memory
=
True
)
self
.
_ngram_pinned_val_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
elif
self
.
speculative_config
.
method
==
"suffix"
:
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
...
@@ -564,7 +588,7 @@ class GPUModelRunner(
...
@@ -564,7 +588,7 @@ class GPUModelRunner(
)
)
self
.
input_batch
=
InputBatch
(
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
# We need to use the encoder length for encoder-decoer
# We need to use the encoder length for encoder-deco
d
er
# because of KV cache for cross-attention.
# because of KV cache for cross-attention.
max_model_len
=
max
(
self
.
max_model_len
,
self
.
max_encoder_len
),
max_model_len
=
max
(
self
.
max_model_len
,
self
.
max_encoder_len
),
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
...
@@ -721,6 +745,21 @@ class GPUModelRunner(
...
@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs.
# Cached outputs.
self
.
_draft_token_ids
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
=
None
self
.
_draft_token_ids
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
=
None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self
.
_num_valid_draft_tokens
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_cpu
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
|
None
=
None
self
.
_num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
|
None
=
None
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
self
.
_num_valid_draft_tokens_cpu
=
torch
.
empty
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
self
.
pin_memory
)
self
.
_num_valid_draft_tokens_event
=
torch
.
cuda
.
Event
()
self
.
_num_valid_draft_tokens_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_draft_token_req_ids
:
list
[
str
]
|
None
=
None
self
.
_draft_token_req_ids
:
list
[
str
]
|
None
=
None
self
.
transfer_event
=
torch
.
Event
()
self
.
transfer_event
=
torch
.
Event
()
self
.
sampled_token_ids_pinned_cpu
=
torch
.
empty
(
self
.
sampled_token_ids_pinned_cpu
=
torch
.
empty
(
...
@@ -992,6 +1031,13 @@ class GPUModelRunner(
...
@@ -992,6 +1031,13 @@ class GPUModelRunner(
for
req_id
in
unscheduled_req_ids
:
for
req_id
in
unscheduled_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
self
.
input_batch
.
remove_request
(
req_id
)
is_ngram_gpu
=
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
)
if
is_ngram_gpu
:
ngram_gpu_new_reqs
:
list
[
CachedRequestState
]
=
[]
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
# Add new requests to the cached states.
# Add new requests to the cached states.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
...
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
...
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self
.
_init_xdrope_positions
(
req_state
)
self
.
_init_xdrope_positions
(
req_state
)
reqs_to_add
.
append
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track new requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
# Update the states of the running/resumed requests.
# Update the states of the running/resumed requests.
is_last_rank
=
get_pp_group
().
is_last_rank
is_last_rank
=
get_pp_group
().
is_last_rank
req_data
=
scheduler_output
.
scheduled_cached_reqs
req_data
=
scheduler_output
.
scheduled_cached_reqs
scheduled_spec_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
scheduled_spec_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req
:
dict
[
str
,
int
]
=
{}
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
for
req_id
,
toks
in
scheduled_spec_tokens
.
items
():
original_num_spec_per_req
[
req_id
]
=
len
(
toks
)
update_scheduler_for_invalid_drafts
(
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens_cpu
,
scheduler_output
,
self
.
input_batch
.
req_id_to_index
,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count
=
self
.
_get_valid_sampled_token_count
()
valid_sampled_token_count
=
self
.
_get_valid_sampled_token_count
()
...
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
...
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
# of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# fi
r
st step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain
# num_computed_tokens in first step and second step does
n
't contain
# the spec tokens length, but in third step it contains the
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
# when prev_num_draft_len > 0.
...
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
...
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens
-=
num_rejected
num_computed_tokens
-=
num_rejected
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
if
is_ngram_gpu
and
num_accepted
>
0
and
req_index
is
not
None
:
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
+=
num_accepted
# Update the cached states.
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
req_state
.
num_computed_tokens
=
num_computed_tokens
...
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
...
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
reqs_to_add
.
append
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track resumed requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
continue
continue
# Update the persistent batch.
# Update the persistent batch.
...
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
...
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu.
# Add spec_token_ids to token_ids_cpu.
self
.
input_batch
.
update_req_spec_token_ids
(
req_state
,
scheduled_spec_tokens
)
self
.
input_batch
.
update_req_spec_token_ids
(
req_state
,
scheduled_spec_tokens
)
# Restore scheduler-side draft count after ngram trimming.
if
original_num_spec_per_req
:
orig
=
original_num_spec_per_req
.
get
(
req_id
,
0
)
if
orig
!=
req_state
.
prev_num_draft_len
:
req_state
.
prev_num_draft_len
=
orig
# Add the new or resumed requests to the persistent batch.
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
# The smaller empty indices are filled first.
...
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
...
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates.
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
self
.
input_batch
.
refresh_metadata
()
# Incrementally update ngram_gpu tensors after batch is stable
if
is_ngram_gpu
:
update_ngram_gpu_tensors_incremental
(
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
ngram_gpu_new_reqs
,
self
.
device
,
_pinned_idx_buf
=
self
.
_ngram_pinned_idx_buf
,
_pinned_val_buf
=
self
.
_ngram_pinned_val_buf
,
)
def
_update_states_after_model_execute
(
def
_update_states_after_model_execute
(
self
,
output_token_ids
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
self
,
output_token_ids
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
)
->
None
:
...
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
...
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else
:
else
:
logger
.
error
(
"RoutedExpertsCapturer not initialized."
)
logger
.
error
(
"RoutedExpertsCapturer not initialized."
)
# If ngram_gpu is used, we need to copy the scheduler_output to avoid
# the modification has influence on the scheduler_output in engine core process.
# The replace is much faster than deepcopy.
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
num_scheduled_tokens_copy
=
scheduler_output
.
num_scheduled_tokens
.
copy
()
spec_decode_tokens_copy
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
copy
()
)
scheduler_output
=
replace
(
scheduler_output
,
num_scheduled_tokens
=
num_scheduled_tokens_copy
,
scheduled_spec_decode_tokens
=
spec_decode_tokens_copy
,
)
if
scheduler_output
.
preempted_req_ids
and
has_kv_transfer_group
():
if
scheduler_output
.
preempted_req_ids
and
has_kv_transfer_group
():
get_kv_transfer_group
().
handle_preemptions
(
get_kv_transfer_group
().
handle_preemptions
(
scheduler_output
.
preempted_req_ids
scheduler_output
.
preempted_req_ids
...
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
...
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self
.
_copy_valid_sampled_token_count
(
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
next_token_ids
,
valid_sampled_tokens_count
)
)
self
.
_draft_token_ids
=
torch
.
zeros
(
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
).
expand
(
len
(
self
.
input_batch
.
req_ids
),
self
.
num_spec_tokens
)
self
.
_copy_draft_token_ids_to_cpu
(
scheduler_output
,
zeros_only
=
True
)
elif
(
spec_config
.
use_ngram_gpu
()
and
not
spec_config
.
disable_padded_drafter_batch
):
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
elif
self
.
valid_sampled_token_count_event
is
not
None
:
assert
spec_decode_common_attn_metadata
is
not
None
next_token_ids
,
valid_sampled_tokens_count
,
_
=
(
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
# just use zeros for the draft tokens.
self
.
_draft_token_ids
=
torch
.
zeros
(
self
.
_draft_token_ids
=
torch
.
zeros
(
...
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
...
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self
.
input_batch
.
token_ids_cpu
,
self
.
input_batch
.
token_ids_cpu
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
)
)
if
isinstance
(
self
.
drafter
,
NgramProposer
):
assert
isinstance
(
sampled_token_ids
,
list
),
(
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
,
self
.
input_batch
.
num_tokens_no_spec
,
self
.
input_batch
.
token_ids_cpu
,
)
elif
spec_config
.
use_ngram_gpu
():
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
(
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
,
)
=
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
batch_size
=
next_token_ids
.
shape
[
0
]
draft_token_ids
,
num_valid_draft_tokens
=
self
.
drafter
.
propose
(
self
.
num_tokens_no_spec_gpu
[:
batch_size
],
self
.
token_ids_gpu_tensor
[:
batch_size
],
valid_sampled_token_ids_gpu
,
valid_sampled_tokens_count
,
)
# Cache valid draft counts for scheduler-side trimming.
self
.
_num_valid_draft_tokens
=
num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens
(
self
.
_num_valid_draft_tokens_cpu
,
self
.
_num_valid_draft_tokens_copy_stream
,
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens
,
self
.
input_batch
.
num_reqs
,
)
elif
spec_config
.
method
==
"suffix"
:
elif
spec_config
.
method
==
"suffix"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
SuffixDecodingProposer
)
assert
isinstance
(
self
.
drafter
,
SuffixDecodingProposer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment