Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6be75db
Unverified
Commit
a6be75db
authored
Mar 08, 2026
by
PatchyTIS
Committed by
GitHub
Mar 07, 2026
Browse files
[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)
parent
ee54f9cd
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
940 additions
and
12 deletions
+940
-12
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+40
-3
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+28
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+7
-0
vllm/config/speculative.py
vllm/config/speculative.py
+9
-1
vllm/config/vllm.py
vllm/config/vllm.py
+5
-2
vllm/tool_parsers/hermes_tool_parser.py
vllm/tool_parsers/hermes_tool_parser.py
+2
-0
vllm/v1/spec_decode/ngram_proposer_gpu.py
vllm/v1/spec_decode/ngram_proposer_gpu.py
+660
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+182
-5
No files found.
tests/v1/e2e/test_async_scheduling.py
View file @
a6be75db
...
...
@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
16
)
def
test_with_spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_with_
eagle3_
spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
...
...
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
)
def
test_with_ngram_gpu_spec_decoding
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test ngram_gpu speculative decoding with different configurations.
This test specifically validates ngram_gpu behavior with various:
- Number of speculative tokens (2-6)
- Prompt lookup window sizes (min/max)
- Async scheduling enabled (as in production)
- Different executors and chunking settings
"""
# Variant with larger speculation window
ngram_gpu_config
=
{
"method"
:
"ngram_gpu"
,
"num_speculative_tokens"
:
3
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
}
# Test configurations covering various scenarios
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs
=
[
(
False
,
"mp"
,
False
,
None
,
False
),
(
False
,
"mp"
,
False
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
False
,
ngram_gpu_config
,
True
),
(
False
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"uni"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
True
),
]
# Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight
# and ngram_gpu doesn't require a specific draft model
run_tests
(
monkeypatch
,
MODEL
,
test_configs
,
[{}])
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
def
run_tests
(
monkeypatch
:
pytest
.
MonkeyPatch
,
...
...
@@ -282,11 +318,12 @@ def run_test(
else
dict
(
gpu_memory_utilization
=
0.9
)
)
spec_mml
=
(
spec_config
or
{}).
get
(
"max_model_len"
)
spec_method
=
(
spec_config
or
{}).
get
(
"method"
,
"none"
)
test_config
=
(
f
"executor=
{
executor
}
, preemption=
{
test_preemption
}
, "
f
"async_sched=
{
async_scheduling
}
, "
f
"chunk_prefill=
{
test_prefill_chunking
}
, "
f
"spec_decoding=
{
spec_decoding
}
, spec_mml=
{
spec_mml
}
"
f
"spec_decoding=
{
spec_decoding
}
,
spec_method=
{
spec_method
}
,
spec_mml=
{
spec_mml
}
"
)
print
(
"-"
*
80
)
print
(
f
"---- TESTING
{
test_str
}
:
{
test_config
}
"
)
...
...
@@ -294,7 +331,7 @@ def run_test(
with
VllmRunner
(
model
,
max_model_len
=
512
,
max_model_len
=
4096
,
enable_chunked_prefill
=
test_prefill_chunking
,
# Force prefill chunking
max_num_batched_tokens
=
48
if
test_prefill_chunking
else
None
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
a6be75db
...
...
@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"async_scheduling"
,
[
True
],
ids
=
[
"async"
])
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
def
test_ngram_gpu_default_with_async_scheduling
(
async_scheduling
:
bool
,
):
"""
Test ngram_gpu speculative decoding (k=3) correctness with and without
async scheduling, validated via GSM8K accuracy.
Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
"""
qwen3_model
=
"Qwen/Qwen3-8B"
spec_llm
=
LLM
(
model
=
qwen3_model
,
speculative_config
=
{
"method"
:
"ngram_gpu"
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
"num_speculative_tokens"
:
2
,
},
max_model_len
=
4096
,
async_scheduling
=
async_scheduling
,
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
0.8
)
del
spec_llm
cleanup_dist_env_and_memory
()
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
def
test_suffix_decoding_acceptance
(
...
...
vllm/compilation/backends.py
View file @
a6be75db
...
...
@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache
=
not
is_compile_cache_enabled
(
self
.
inductor_config
)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled
=
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
use_ngram_gpu
()
)
disable_cache
=
disable_cache
or
is_ngram_gpu_enabled
if
disable_cache
:
logger
.
info_once
(
"vLLM's torch.compile cache is disabled."
,
scope
=
"local"
)
else
:
...
...
vllm/config/speculative.py
View file @
a6be75db
...
...
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp"
,
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
NgramGPUTypes
=
Literal
[
"ngram_gpu"
]
SpeculativeMethod
=
Literal
[
"ngram"
,
"medusa"
,
...
...
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model"
,
"suffix"
,
EagleModelTypes
,
NgramGPUTypes
,
]
...
...
@@ -364,6 +366,8 @@ class SpeculativeConfig:
self
.
quantization
=
self
.
target_model_config
.
quantization
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
self
.
model
=
"ngram"
elif
self
.
method
==
"ngram_gpu"
:
self
.
model
=
"ngram_gpu"
elif
self
.
method
==
"suffix"
:
self
.
model
=
"suffix"
elif
self
.
method
==
"extract_hidden_states"
:
...
...
@@ -374,8 +378,9 @@ class SpeculativeConfig:
)
if
self
.
method
in
(
"ngram"
,
"[ngram]"
):
# Unified to "ngram" internally
self
.
method
=
"ngram"
if
self
.
method
in
(
"ngram"
,
"ngram_gpu"
):
# Set default values if not provided
if
self
.
prompt_lookup_min
is
None
and
self
.
prompt_lookup_max
is
None
:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
...
...
@@ -832,6 +837,9 @@ class SpeculativeConfig:
def
uses_extract_hidden_states
(
self
)
->
bool
:
return
self
.
method
==
"extract_hidden_states"
def
use_ngram_gpu
(
self
)
->
bool
:
return
self
.
method
==
"ngram_gpu"
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
model
=
(
...
...
vllm/config/vllm.py
View file @
a6be75db
...
...
@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from
.parallel
import
ParallelConfig
from
.profiler
import
ProfilerConfig
from
.scheduler
import
SchedulerConfig
from
.speculative
import
EagleModelTypes
,
SpeculativeConfig
from
.speculative
import
EagleModelTypes
,
NgramGPUTypes
,
SpeculativeConfig
from
.structured_outputs
import
StructuredOutputsConfig
from
.utils
import
SupportsHash
,
config
,
replace
from
.weight_transfer
import
WeightTransferConfig
...
...
@@ -696,11 +696,13 @@ class VllmConfig:
if
self
.
speculative_config
is
not
None
:
if
(
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
and
self
.
speculative_config
.
method
!=
"draft_model"
):
raise
ValueError
(
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding."
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
...
...
@@ -718,6 +720,7 @@ class VllmConfig:
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
):
logger
.
warning_once
(
"Async scheduling not supported with %s-based "
...
...
vllm/tool_parsers/hermes_tool_parser.py
View file @
a6be75db
...
...
@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
assert
current_tool_call
is
not
None
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
...
...
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
assert
isinstance
(
current_tool_call
,
dict
)
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
else
:
...
...
vllm/v1/spec_decode/ngram_proposer_gpu.py
0 → 100644
View file @
a6be75db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import
torch
from
torch
import
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
@
support_torch_compile
()
class
NgramGPUKernel
(
nn
.
Module
):
"""GPU-accelerated N-gram proposer using fully async tensor operations."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
device
:
torch
.
device
=
"cuda"
):
super
().
__init__
()
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
def
_find_first_and_extract_all_n_parallel
(
self
,
token_ids
:
torch
.
Tensor
,
seq_lengths
:
torch
.
Tensor
,
min_ngram_len
:
int
,
max_ngram_len
:
int
,
num_draft_tokens
:
int
,
)
->
torch
.
Tensor
:
"""
Find suffix n-gram matches and extract following tokens.
Searches for the earliest prior occurrence of the trailing n-gram,
tries multiple lengths, and picks the longest valid match.
Args:
token_ids: Token IDs for each sequence
seq_lengths: Actual length of each sequence (excluding padding)
min_ngram_len: Minimum n-gram size to search for (e.g., 2)
max_ngram_len: Maximum n-gram size to search for (e.g., 5)
num_draft_tokens: Number of tokens to extract after match (k)
Returns:
Draft token predictions; -1 means invalid/no match.
"""
batch_size
=
token_ids
.
shape
[
0
]
max_seq_len
=
token_ids
.
shape
[
1
]
device
=
token_ids
.
device
num_ngram_sizes
=
max_ngram_len
-
min_ngram_len
+
1
# All n-gram sizes to try.
ngram_lengths
=
torch
.
arange
(
min_ngram_len
,
max_ngram_len
+
1
,
device
=
device
)
batch_indices
=
torch
.
arange
(
batch_size
,
device
=
device
)
# Earliest match per (sequence, ngram_len); -1 means no match.
first_match_positions
=
torch
.
full
(
(
batch_size
,
num_ngram_sizes
),
-
1
,
dtype
=
torch
.
long
,
device
=
device
)
for
i
,
ngram_len
in
enumerate
(
range
(
min_ngram_len
,
max_ngram_len
+
1
)):
# Sliding windows of size ngram_len; unfold is O(1) view.
search_windows
=
token_ids
.
unfold
(
1
,
ngram_len
,
1
)
num_windows
=
search_windows
.
shape
[
1
]
# Trailing suffix (last ngram_len tokens) for each sequence.
suffix_starts
=
seq_lengths
-
ngram_len
suffix_indices
=
suffix_starts
.
unsqueeze
(
1
)
+
torch
.
arange
(
ngram_len
,
device
=
device
)
suffix
=
torch
.
gather
(
token_ids
,
1
,
suffix_indices
.
clamp
(
min
=
0
))
# Window matches for each sequence.
matches
=
(
search_windows
==
suffix
.
unsqueeze
(
1
)).
all
(
dim
=-
1
)
# Match must leave room for at least one draft token.
max_valid_suffix_start
=
seq_lengths
-
ngram_len
-
1
window_positions
=
torch
.
arange
(
num_windows
,
device
=
device
)
valid_mask
=
window_positions
<=
max_valid_suffix_start
.
unsqueeze
(
1
)
final_matches
=
matches
&
valid_mask
# Find earliest match (argmax=0 when empty; verify with has_match).
first_match_idx
=
torch
.
argmax
(
final_matches
.
int
(),
dim
=
1
)
has_match
=
final_matches
[
batch_indices
,
first_match_idx
]
# Store valid match positions (window index = position).
first_match_positions
[:,
i
]
=
torch
.
where
(
has_match
,
first_match_idx
,
-
1
)
# Select the longest n-gram with a match.
best_ngram_idx
=
(
first_match_positions
>=
0
).
int
().
flip
(
dims
=
[
1
]).
argmax
(
dim
=
1
)
best_ngram_idx
=
num_ngram_sizes
-
1
-
best_ngram_idx
# Flip back
# Match position for the best n-gram.
best_match_pos
=
first_match_positions
[
batch_indices
,
best_ngram_idx
]
# Avoid data-dependent branching.
has_any_match
=
best_match_pos
>=
0
# Length of the best matching n-gram.
best_ngram_lengths
=
ngram_lengths
[
best_ngram_idx
]
# Start position right after the matched suffix.
draft_start
=
torch
.
where
(
has_any_match
,
best_match_pos
+
best_ngram_lengths
,
torch
.
zeros_like
(
best_match_pos
),
)
tokens_available
=
seq_lengths
-
draft_start
# Gather indices for draft tokens.
draft_indices
=
draft_start
.
unsqueeze
(
1
)
+
torch
.
arange
(
num_draft_tokens
,
device
=
device
)
draft_indices
=
draft_indices
.
clamp
(
min
=
0
,
max
=
max_seq_len
-
1
)
# Extract draft tokens; gather always runs.
draft_tokens
=
torch
.
gather
(
token_ids
,
1
,
draft_indices
)
# Mask positions beyond available tokens.
position_indices
=
torch
.
arange
(
num_draft_tokens
,
device
=
device
).
unsqueeze
(
0
)
valid_positions
=
position_indices
<
tokens_available
.
unsqueeze
(
1
)
draft_tokens
=
torch
.
where
(
valid_positions
,
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
# If no match, mask all positions.
draft_tokens
=
torch
.
where
(
has_any_match
.
unsqueeze
(
1
),
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
return
draft_tokens
def
forward
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
token_ids_gpu
:
torch
.
Tensor
,
combined_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Forward pass for N-gram proposal using GPU tensor operations.
Args:
num_tokens_no_spec: Number of tokens for each sequence [batch_size]
token_ids_gpu: Token IDs [batch_size, max_len]
combined_mask: Whether each sequence is valid for spec decode [batch_size]
Returns:
draft_tokens: [batch_size, k] on GPU
num_valid_draft_tokens: [batch_size] int32 on GPU, count of
leading valid (non -1) tokens per request.
"""
device
=
token_ids_gpu
.
device
# Infer batch size to preserve dynamic shape.
actual_batch_size
=
token_ids_gpu
.
shape
[
0
]
# Allocate in forward so torch.compile can optimize.
# NOTE(patchy): Do NOT pre-allocate this as a buffer
# it breaks torch.compile
draft_tokens
=
torch
.
full
(
(
actual_batch_size
,
self
.
k
),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
results
=
self
.
_find_first_and_extract_all_n_parallel
(
token_ids_gpu
,
num_tokens_no_spec
,
min_ngram_len
=
self
.
min_n
,
max_ngram_len
=
self
.
max_n
,
num_draft_tokens
=
self
.
k
,
)
draft_tokens
=
torch
.
where
(
combined_mask
.
unsqueeze
(
1
),
results
,
-
1
)
# Count leading contiguous valid (non -1) tokens per request.
is_valid
=
draft_tokens
!=
-
1
# [batch, k]
cum_valid
=
is_valid
.
int
().
cumsum
(
dim
=
1
)
# [batch, k]
positions
=
torch
.
arange
(
1
,
self
.
k
+
1
,
device
=
device
).
unsqueeze
(
0
)
num_valid_draft_tokens
=
(
cum_valid
==
positions
).
int
().
sum
(
dim
=
1
)
return
draft_tokens
,
num_valid_draft_tokens
def
load_model
(
self
,
*
args
,
**
kwargs
):
"""No model to load for N-gram proposer."""
pass
class
NgramProposerGPU
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
[
"none"
],
splitting_ops
=
[],
compile_sizes
=
[],
inductor_compile_config
=
{
"enable_auto_functionalized_v2"
:
False
,
"max_autotune"
:
True
,
"aggressive_fusion"
:
True
,
"triton.autotune_pointwise"
:
True
,
"coordinate_descent_tuning"
:
True
,
"use_mixed_mm"
:
False
,
},
cudagraph_mode
=
CUDAGraphMode
.
NONE
,
)
model_config
=
vllm_config
.
model_config
speculative_config
=
vllm_config
.
speculative_config
scheduler_config
=
vllm_config
.
scheduler_config
self
.
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
,
model_config
=
model_config
,
speculative_config
=
speculative_config
,
scheduler_config
=
scheduler_config
,
)
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
self
.
kernel
=
NgramGPUKernel
(
vllm_config
=
self
.
vllm_config
,
prefix
=
"ngram_gpu_kernel"
,
device
=
device
)
self
.
kernel
.
to
(
device
)
self
.
kernel
.
eval
()
self
.
_dummy_run
()
def
_dummy_run
(
self
):
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
=
self
.
_generate_dummy_data
(
batch_size
=
self
.
max_num_seqs
,
max_seq_len
=
self
.
max_model_len
,
pattern_len
=
self
.
k
,
device
=
self
.
device
,
)
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens
>=
self
.
min_n
)
for
_
in
range
(
3
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
_
,
_
=
self
.
kernel
(
num_tokens
,
token_ids
,
combined_mask
)
def
_generate_dummy_data
(
self
,
batch_size
:
int
,
max_seq_len
:
int
,
pattern_len
:
int
,
device
:
str
=
"cuda"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Generate random test data with n-gram repetitions.
Args:
batch_size: Number of sequences in the batch
max_seq_len: Maximum sequence length
pattern_len: Length of patterns to inject for matching
device: Device to place tensors on
Returns:
token_ids: [batch_size, max_seq_len] tensor
num_tokens: [batch_size] tensor
sampled_flags: [batch_size] bool tensor
valid_mask: [batch_size] bool tensor
"""
token_ids
=
torch
.
zeros
(
batch_size
,
max_seq_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
num_tokens
=
torch
.
randint
(
pattern_len
,
max_seq_len
,
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
device
)
sampled_flags
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
return
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
def
propose
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
# [batch_size]
token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, max_len]
valid_sampled_token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, num_spec_tokens + 1]
valid_sampled_tokens_count
:
torch
.
Tensor
,
# [batch_size]
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Propose draft tokens using GPU-accelerated n-gram matching.
Scatter sampled tokens into `token_ids_gpu`, compute temporary
updated lengths, then run the kernel.
Args:
num_tokens_no_spec: Number of tokens per sequence (read-only)
token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
valid_sampled_tokens_count: Count of valid tokens per sequence
Returns:
draft_tokens: Proposed draft token IDs [batch_size, k]
num_valid_draft_tokens: Count of leading valid draft tokens
per request [batch_size]
"""
assert
token_ids_gpu
.
device
==
self
.
device
assert
num_tokens_no_spec
.
device
==
self
.
device
batch_size
=
num_tokens_no_spec
.
shape
[
0
]
max_seq_len
=
token_ids_gpu
.
shape
[
1
]
max_new_tokens
=
valid_sampled_token_ids_gpu
.
shape
[
1
]
# num_spec_tokens + 1
# Scatter newly sampled tokens into token_ids_gpu.
offsets
=
torch
.
arange
(
max_new_tokens
,
device
=
self
.
device
)
write_positions
=
num_tokens_no_spec
.
unsqueeze
(
1
)
+
offsets
.
unsqueeze
(
0
)
valid_write_mask
=
offsets
.
unsqueeze
(
0
)
<
valid_sampled_tokens_count
.
unsqueeze
(
1
)
in_bounds
=
write_positions
<
max_seq_len
scatter_mask
=
(
valid_write_mask
&
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
in_bounds
)
write_positions_long
=
write_positions
.
clamp
(
max
=
max_seq_len
-
1
).
long
()
existing_values
=
token_ids_gpu
.
gather
(
1
,
write_positions_long
)
tokens_cast
=
valid_sampled_token_ids_gpu
.
to
(
token_ids_gpu
.
dtype
)
tokens_to_scatter
=
torch
.
where
(
scatter_mask
,
tokens_cast
,
existing_values
,
)
token_ids_gpu
.
scatter_
(
1
,
write_positions_long
,
tokens_to_scatter
)
num_tokens_tmp
=
num_tokens_no_spec
+
valid_sampled_tokens_count
# Compute validity masks.
sampled_flags
=
valid_sampled_tokens_count
>
0
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
with
set_forward_context
(
None
,
self
.
vllm_config
):
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens_tmp
>=
self
.
min_n
)
with
record_function_or_nullcontext
(
"ngram_proposer_gpu: kernel"
):
draft_tokens
,
num_valid_draft_tokens
=
self
.
kernel
(
num_tokens_tmp
,
token_ids_gpu
,
combined_mask
,
)
return
draft_tokens
,
num_valid_draft_tokens
def
update_token_ids_ngram
(
self
,
sampled_token_ids
:
torch
.
Tensor
|
list
[
list
[
int
]],
gpu_input_batch
:
InputBatch
,
token_ids_gpu
:
torch
.
Tensor
,
num_tokens_no_spec
:
torch
.
Tensor
,
discard_request_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Prepare speculative decoding inputs on device:
compute next token ids and valid counts, honoring discarded requests
and rejected tokens, without CPU-GPU sync.
"""
num_reqs
=
gpu_input_batch
.
num_reqs
if
isinstance
(
sampled_token_ids
,
list
):
# When disable_padded_drafter_batch=True, sampled_token_ids is
# an irregular list[list[int]] where sublists may have different
# lengths (including empty lists for discarded requests).
# Pad all sublists to the same length with -1 before converting
# to tensor.
max_len
=
max
(
(
len
(
sublist
)
for
sublist
in
sampled_token_ids
),
default
=
0
,
)
# Ensure at least length 1 for tensor creation
max_len
=
max
(
max_len
,
1
)
padded_list
=
[
sublist
+
[
-
1
]
*
(
max_len
-
len
(
sublist
))
for
sublist
in
sampled_token_ids
]
sampled_token_ids
=
torch
.
tensor
(
padded_list
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
(
"sampled_token_ids should be a torch.Tensor for ngram_gpu"
)
# Backup last valid token before speculative tokens.
backup_indices
=
(
num_tokens_no_spec
[:
num_reqs
]
-
1
).
clamp
(
min
=
0
).
long
()
backup_next_token_ids
=
torch
.
gather
(
token_ids_gpu
[:
num_reqs
],
dim
=
1
,
index
=
backup_indices
.
unsqueeze
(
1
)
).
squeeze
(
1
)
valid_sampled_token_ids_gpu
=
sampled_token_ids
.
clone
()
# Invalidate sampled tokens for discarded requests.
discard_mask_expanded
=
discard_request_mask
[:
num_reqs
].
unsqueeze
(
1
)
valid_sampled_token_ids_gpu
.
masked_fill_
(
discard_mask_expanded
,
-
1
)
# Mask valid tokens within each request.
valid_mask
=
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
(
valid_sampled_token_ids_gpu
<
gpu_input_batch
.
vocab_size
)
# Count valid tokens per request.
valid_sampled_tokens_count
=
valid_mask
.
sum
(
dim
=
1
)
# Rightmost valid index per row.
last_valid_indices
=
valid_sampled_tokens_count
-
1
last_valid_indices_safe
=
torch
.
clamp
(
last_valid_indices
,
min
=
0
)
# Last valid token from each row; undefined if none.
selected_tokens
=
torch
.
gather
(
valid_sampled_token_ids_gpu
,
1
,
last_valid_indices_safe
.
unsqueeze
(
1
)
).
squeeze
(
1
)
# Use last token if valid; otherwise fallback to backup.
next_token_ids
=
torch
.
where
(
last_valid_indices
!=
-
1
,
selected_tokens
,
backup_next_token_ids
,
)
return
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
def
load_model
(
self
,
*
args
,
**
kwargs
):
self
.
kernel
.
load_model
(
*
args
,
**
kwargs
)
def
update_scheduler_for_invalid_drafts
(
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
,
req_id_to_index
:
dict
[
str
,
int
],
)
->
None
:
"""Trim invalid speculative slots using per-request valid draft counts.
Args:
num_valid_draft_tokens_event: Event for async D2H completion.
num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
scheduler_output: Scheduler metadata to update in-place.
req_id_to_index: Request-id to batch-index mapping.
"""
req_data
=
scheduler_output
.
scheduled_cached_reqs
num_valid_draft_tokens_event
.
synchronize
()
for
req_id
in
req_data
.
req_ids
:
req_index
=
req_id_to_index
.
get
(
req_id
)
if
req_index
is
None
:
continue
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
)
if
spec_token_ids
is
None
:
continue
scheduled_k
=
len
(
spec_token_ids
)
valid_k
=
int
(
num_valid_draft_tokens_cpu
[
req_index
].
item
())
valid_k
=
max
(
0
,
min
(
valid_k
,
scheduled_k
))
tokens_to_trim
=
scheduled_k
-
valid_k
scheduler_output
.
total_num_scheduled_tokens
-=
tokens_to_trim
scheduler_output
.
num_scheduled_tokens
[
req_id
]
-=
tokens_to_trim
if
valid_k
==
0
:
scheduler_output
.
scheduled_spec_decode_tokens
.
pop
(
req_id
,
None
)
else
:
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
]
=
spec_token_ids
[
:
valid_k
]
def
update_ngram_gpu_tensors_incremental
(
input_batch
:
InputBatch
,
token_ids_gpu_tensor
:
torch
.
Tensor
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
new_reqs
:
list
[
CachedRequestState
],
device
:
torch
.
device
,
_pinned_idx_buf
:
torch
.
Tensor
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
for ngram GPU proposer.
"""
prev_req_id_to_index
=
input_batch
.
prev_req_id_to_index
curr_req_id_to_index
=
input_batch
.
req_id_to_index
if
not
curr_req_id_to_index
:
return
active_indices
=
list
(
curr_req_id_to_index
.
values
())
n_active
=
len
(
active_indices
)
# Use resident pinned buffers to avoid per-call allocation.
active_idx_cpu
=
_pinned_idx_buf
[:
n_active
]
active_idx_cpu
.
copy_
(
torch
.
as_tensor
(
active_indices
,
dtype
=
torch
.
long
))
active_idx_gpu
=
active_idx_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
new_req_ids
=
{
req
.
req_id
for
req
in
new_reqs
}
# First run, no previous state.
if
prev_req_id_to_index
is
None
:
for
idx
in
active_indices
:
num_tokens
=
input_batch
.
num_tokens_no_spec
[
idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
idx
,
:
num_tokens
],
non_blocking
=
True
,
)
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
return
# Detect index changes for reorder.
reorder_src
:
list
[
int
]
=
[]
reorder_dst
:
list
[
int
]
=
[]
for
req_id
,
curr_idx
in
curr_req_id_to_index
.
items
():
if
req_id
in
new_req_ids
:
continue
prev_idx
=
prev_req_id_to_index
.
get
(
req_id
)
if
prev_idx
is
not
None
and
prev_idx
!=
curr_idx
:
reorder_src
.
append
(
prev_idx
)
reorder_dst
.
append
(
curr_idx
)
if
reorder_src
:
src_tensor
=
torch
.
tensor
(
reorder_src
,
dtype
=
torch
.
long
,
device
=
device
)
dst_tensor
=
torch
.
tensor
(
reorder_dst
,
dtype
=
torch
.
long
,
device
=
device
)
temp_token_ids
=
token_ids_gpu_tensor
[
src_tensor
].
clone
()
temp_num_tokens
=
num_tokens_no_spec_gpu
[
src_tensor
].
clone
()
token_ids_gpu_tensor
[
dst_tensor
]
=
temp_token_ids
num_tokens_no_spec_gpu
[
dst_tensor
]
=
temp_num_tokens
# Full copy for new/resumed requests.
for
req_state
in
new_reqs
:
new_req_idx
=
curr_req_id_to_index
.
get
(
req_state
.
req_id
)
if
new_req_idx
is
None
:
continue
num_tokens
=
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
new_req_idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
new_req_idx
,
:
num_tokens
],
non_blocking
=
True
,
)
# Always batch-sync sequence lengths from CPU for ALL active requests.
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
def
_sync_num_tokens
(
input_batch
:
InputBatch
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
active_idx_cpu
:
torch
.
Tensor
,
active_idx_gpu
:
torch
.
Tensor
,
n_active
:
int
,
device
:
torch
.
device
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Batch-sync GPU sequence lengths from CPU source of truth.
Inputs:
input_batch: Batch container with CPU length tensor.
num_tokens_no_spec_gpu: Destination GPU length tensor.
active_idx_cpu: Active request indices on CPU.
active_idx_gpu: Active request indices on GPU.
n_active: Number of active requests.
device: Target CUDA device.
_pinned_val_buf: Resident pinned int32 staging buffer.
Outputs:
None (updates num_tokens_no_spec_gpu in-place).
"""
src_cpu
=
input_batch
.
num_tokens_no_spec_cpu_tensor
vals
=
_pinned_val_buf
[:
n_active
]
vals
.
copy_
(
src_cpu
.
index_select
(
0
,
active_idx_cpu
))
num_tokens_no_spec_gpu
.
index_copy_
(
0
,
active_idx_gpu
,
vals
.
to
(
device
=
device
,
non_blocking
=
True
),
)
def
copy_num_valid_draft_tokens
(
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
,
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens
:
torch
.
Tensor
|
None
,
batch_size
:
int
,
)
->
None
:
"""
Async D2H copy of per-request valid draft counts.
"""
if
num_valid_draft_tokens
is
None
:
return
num_reqs_to_copy
=
min
(
batch_size
,
num_valid_draft_tokens
.
shape
[
0
])
if
num_reqs_to_copy
<=
0
:
return
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
num_valid_draft_tokens_copy_stream
):
num_valid_draft_tokens_copy_stream
.
wait_stream
(
default_stream
)
num_valid_draft_tokens_cpu
[:
num_reqs_to_copy
].
copy_
(
num_valid_draft_tokens
[:
num_reqs_to_copy
],
non_blocking
=
True
)
num_valid_draft_tokens_event
.
record
()
vllm/v1/worker/gpu_input_batch.py
View file @
a6be75db
...
...
@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self
.
req_prompt_embeds
:
dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_tokens_no_spec
=
self
.
num_tokens_no_spec_cpu_tensor
.
numpy
()
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a6be75db
...
...
@@ -10,7 +10,7 @@ from collections import defaultdict
from
collections.abc
import
Iterable
,
Iterator
,
Sequence
from
contextlib
import
contextmanager
from
copy
import
copy
,
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
functools
import
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
...
...
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
NgramProposerGPU
,
copy_num_valid_draft_tokens
,
update_ngram_gpu_tensors_incremental
,
update_scheduler_for_invalid_drafts
,
)
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.structured_output.utils
import
apply_grammar_bitmask
from
vllm.v1.utils
import
CpuGpuBuffer
,
record_function_or_nullcontext
...
...
@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mi
r
co-batches
# TODO: Support overlapping mic
r
o-batches
# https://github.com/vllm-project/vllm/issues/18019
self
.
broadcast_pp_output
=
(
self
.
parallel_config
.
distributed_executor_backend
==
"external_launcher"
...
...
@@ -493,6 +499,7 @@ class GPUModelRunner(
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
self
.
drafter
:
(
NgramProposer
# noqa: F823
|
NgramProposerGPU
|
SuffixDecodingProposer
|
EagleProposer
|
DraftModelProposer
...
...
@@ -509,6 +516,23 @@ class GPUModelRunner(
device
=
self
.
device
,
runner
=
self
,
)
elif
self
.
speculative_config
.
use_ngram_gpu
():
self
.
drafter
=
NgramProposerGPU
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
num_tokens_no_spec_gpu
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
token_ids_gpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
_ngram_pinned_idx_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
long
,
pin_memory
=
True
)
self
.
_ngram_pinned_val_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
...
...
@@ -564,7 +588,7 @@ class GPUModelRunner(
)
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
# We need to use the encoder length for encoder-decoer
# We need to use the encoder length for encoder-deco
d
er
# because of KV cache for cross-attention.
max_model_len
=
max
(
self
.
max_model_len
,
self
.
max_encoder_len
),
max_num_batched_tokens
=
self
.
max_num_tokens
,
...
...
@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs.
self
.
_draft_token_ids
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
=
None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self
.
_num_valid_draft_tokens
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_cpu
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
|
None
=
None
self
.
_num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
|
None
=
None
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
self
.
_num_valid_draft_tokens_cpu
=
torch
.
empty
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
self
.
pin_memory
)
self
.
_num_valid_draft_tokens_event
=
torch
.
cuda
.
Event
()
self
.
_num_valid_draft_tokens_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_draft_token_req_ids
:
list
[
str
]
|
None
=
None
self
.
transfer_event
=
torch
.
Event
()
self
.
sampled_token_ids_pinned_cpu
=
torch
.
empty
(
...
...
@@ -992,6 +1031,13 @@ class GPUModelRunner(
for
req_id
in
unscheduled_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
is_ngram_gpu
=
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
)
if
is_ngram_gpu
:
ngram_gpu_new_reqs
:
list
[
CachedRequestState
]
=
[]
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
# Add new requests to the cached states.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
...
...
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self
.
_init_xdrope_positions
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track new requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
# Update the states of the running/resumed requests.
is_last_rank
=
get_pp_group
().
is_last_rank
req_data
=
scheduler_output
.
scheduled_cached_reqs
scheduled_spec_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req
:
dict
[
str
,
int
]
=
{}
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
for
req_id
,
toks
in
scheduled_spec_tokens
.
items
():
original_num_spec_per_req
[
req_id
]
=
len
(
toks
)
update_scheduler_for_invalid_drafts
(
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens_cpu
,
scheduler_output
,
self
.
input_batch
.
req_id_to_index
,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count
=
self
.
_get_valid_sampled_token_count
()
...
...
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# fi
r
st step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain
# num_computed_tokens in first step and second step does
n
't contain
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
...
...
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens
-=
num_rejected
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
if
is_ngram_gpu
and
num_accepted
>
0
and
req_index
is
not
None
:
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
+=
num_accepted
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
...
...
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
reqs_to_add
.
append
(
req_state
)
# Track resumed requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
continue
# Update the persistent batch.
...
...
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu.
self
.
input_batch
.
update_req_spec_token_ids
(
req_state
,
scheduled_spec_tokens
)
# Restore scheduler-side draft count after ngram trimming.
if
original_num_spec_per_req
:
orig
=
original_num_spec_per_req
.
get
(
req_id
,
0
)
if
orig
!=
req_state
.
prev_num_draft_len
:
req_state
.
prev_num_draft_len
=
orig
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
...
...
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
# Incrementally update ngram_gpu tensors after batch is stable
if
is_ngram_gpu
:
update_ngram_gpu_tensors_incremental
(
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
ngram_gpu_new_reqs
,
self
.
device
,
_pinned_idx_buf
=
self
.
_ngram_pinned_idx_buf
,
_pinned_val_buf
=
self
.
_ngram_pinned_val_buf
,
)
def
_update_states_after_model_execute
(
self
,
output_token_ids
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
...
...
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else
:
logger
.
error
(
"RoutedExpertsCapturer not initialized."
)
# If ngram_gpu is used, we need to copy the scheduler_output to avoid
# the modification has influence on the scheduler_output in engine core process.
# The replace is much faster than deepcopy.
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
num_scheduled_tokens_copy
=
scheduler_output
.
num_scheduled_tokens
.
copy
()
spec_decode_tokens_copy
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
copy
()
)
scheduler_output
=
replace
(
scheduler_output
,
num_scheduled_tokens
=
num_scheduled_tokens_copy
,
scheduled_spec_decode_tokens
=
spec_decode_tokens_copy
,
)
if
scheduler_output
.
preempted_req_ids
and
has_kv_transfer_group
():
get_kv_transfer_group
().
handle_preemptions
(
scheduler_output
.
preempted_req_ids
...
...
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
self
.
_draft_token_ids
=
torch
.
zeros
(
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
).
expand
(
len
(
self
.
input_batch
.
req_ids
),
self
.
num_spec_tokens
)
self
.
_copy_draft_token_ids_to_cpu
(
scheduler_output
,
zeros_only
=
True
)
elif
(
spec_config
.
use_ngram_gpu
()
and
not
spec_config
.
disable_padded_drafter_batch
):
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
elif
self
.
valid_sampled_token_count_event
is
not
None
:
assert
spec_decode_common_attn_metadata
is
not
None
next_token_ids
,
valid_sampled_tokens_count
,
_
=
(
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self
.
_draft_token_ids
=
torch
.
zeros
(
...
...
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self
.
input_batch
.
token_ids_cpu
,
slot_mappings
=
slot_mappings
,
)
if
isinstance
(
self
.
drafter
,
NgramProposer
):
assert
isinstance
(
sampled_token_ids
,
list
),
(
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
,
self
.
input_batch
.
num_tokens_no_spec
,
self
.
input_batch
.
token_ids_cpu
,
)
elif
spec_config
.
use_ngram_gpu
():
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
(
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
,
)
=
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
batch_size
=
next_token_ids
.
shape
[
0
]
draft_token_ids
,
num_valid_draft_tokens
=
self
.
drafter
.
propose
(
self
.
num_tokens_no_spec_gpu
[:
batch_size
],
self
.
token_ids_gpu_tensor
[:
batch_size
],
valid_sampled_token_ids_gpu
,
valid_sampled_tokens_count
,
)
# Cache valid draft counts for scheduler-side trimming.
self
.
_num_valid_draft_tokens
=
num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens
(
self
.
_num_valid_draft_tokens_cpu
,
self
.
_num_valid_draft_tokens_copy_stream
,
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens
,
self
.
input_batch
.
num_reqs
,
)
elif
spec_config
.
method
==
"suffix"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
SuffixDecodingProposer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment