Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6be75db
Unverified
Commit
a6be75db
authored
Mar 08, 2026
by
PatchyTIS
Committed by
GitHub
Mar 07, 2026
Browse files
[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)
parent
ee54f9cd
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
940 additions
and
12 deletions
+940
-12
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+40
-3
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+28
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+7
-0
vllm/config/speculative.py
vllm/config/speculative.py
+9
-1
vllm/config/vllm.py
vllm/config/vllm.py
+5
-2
vllm/tool_parsers/hermes_tool_parser.py
vllm/tool_parsers/hermes_tool_parser.py
+2
-0
vllm/v1/spec_decode/ngram_proposer_gpu.py
vllm/v1/spec_decode/ngram_proposer_gpu.py
+660
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+182
-5
No files found.
tests/v1/e2e/test_async_scheduling.py
View file @
a6be75db
...
...
@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
16
)
def
test_with_spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_with_
eagle3_
spec_decoding
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
...
...
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
)
def
test_with_ngram_gpu_spec_decoding
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test ngram_gpu speculative decoding with different configurations.
This test specifically validates ngram_gpu behavior with various:
- Number of speculative tokens (2-6)
- Prompt lookup window sizes (min/max)
- Async scheduling enabled (as in production)
- Different executors and chunking settings
"""
# Variant with larger speculation window
ngram_gpu_config
=
{
"method"
:
"ngram_gpu"
,
"num_speculative_tokens"
:
3
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
}
# Test configurations covering various scenarios
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs
=
[
(
False
,
"mp"
,
False
,
None
,
False
),
(
False
,
"mp"
,
False
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
False
,
ngram_gpu_config
,
True
),
(
False
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"uni"
,
True
,
ngram_gpu_config
,
False
),
(
True
,
"mp"
,
True
,
ngram_gpu_config
,
True
),
]
# Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight
# and ngram_gpu doesn't require a specific draft model
run_tests
(
monkeypatch
,
MODEL
,
test_configs
,
[{}])
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
def
run_tests
(
monkeypatch
:
pytest
.
MonkeyPatch
,
...
...
@@ -282,11 +318,12 @@ def run_test(
else
dict
(
gpu_memory_utilization
=
0.9
)
)
spec_mml
=
(
spec_config
or
{}).
get
(
"max_model_len"
)
spec_method
=
(
spec_config
or
{}).
get
(
"method"
,
"none"
)
test_config
=
(
f
"executor=
{
executor
}
, preemption=
{
test_preemption
}
, "
f
"async_sched=
{
async_scheduling
}
, "
f
"chunk_prefill=
{
test_prefill_chunking
}
, "
f
"spec_decoding=
{
spec_decoding
}
, spec_mml=
{
spec_mml
}
"
f
"spec_decoding=
{
spec_decoding
}
,
spec_method=
{
spec_method
}
,
spec_mml=
{
spec_mml
}
"
)
print
(
"-"
*
80
)
print
(
f
"---- TESTING
{
test_str
}
:
{
test_config
}
"
)
...
...
@@ -294,7 +331,7 @@ def run_test(
with
VllmRunner
(
model
,
max_model_len
=
512
,
max_model_len
=
4096
,
enable_chunked_prefill
=
test_prefill_chunking
,
# Force prefill chunking
max_num_batched_tokens
=
48
if
test_prefill_chunking
else
None
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
a6be75db
...
...
@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"async_scheduling"
,
[
True
],
ids
=
[
"async"
])
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
def
test_ngram_gpu_default_with_async_scheduling
(
async_scheduling
:
bool
,
):
"""
Test ngram_gpu speculative decoding (k=3) correctness with and without
async scheduling, validated via GSM8K accuracy.
Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
"""
qwen3_model
=
"Qwen/Qwen3-8B"
spec_llm
=
LLM
(
model
=
qwen3_model
,
speculative_config
=
{
"method"
:
"ngram_gpu"
,
"prompt_lookup_max"
:
3
,
"prompt_lookup_min"
:
2
,
"num_speculative_tokens"
:
2
,
},
max_model_len
=
4096
,
async_scheduling
=
async_scheduling
,
)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
0.8
)
del
spec_llm
cleanup_dist_env_and_memory
()
@
single_gpu_only
@
large_gpu_mark
(
min_gb
=
20
)
def
test_suffix_decoding_acceptance
(
...
...
vllm/compilation/backends.py
View file @
a6be75db
...
...
@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache
=
not
is_compile_cache_enabled
(
self
.
inductor_config
)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled
=
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
use_ngram_gpu
()
)
disable_cache
=
disable_cache
or
is_ngram_gpu_enabled
if
disable_cache
:
logger
.
info_once
(
"vLLM's torch.compile cache is disabled."
,
scope
=
"local"
)
else
:
...
...
vllm/config/speculative.py
View file @
a6be75db
...
...
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp"
,
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
NgramGPUTypes
=
Literal
[
"ngram_gpu"
]
SpeculativeMethod
=
Literal
[
"ngram"
,
"medusa"
,
...
...
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model"
,
"suffix"
,
EagleModelTypes
,
NgramGPUTypes
,
]
...
...
@@ -364,6 +366,8 @@ class SpeculativeConfig:
self
.
quantization
=
self
.
target_model_config
.
quantization
elif
self
.
method
in
(
"ngram"
,
"[ngram]"
):
self
.
model
=
"ngram"
elif
self
.
method
==
"ngram_gpu"
:
self
.
model
=
"ngram_gpu"
elif
self
.
method
==
"suffix"
:
self
.
model
=
"suffix"
elif
self
.
method
==
"extract_hidden_states"
:
...
...
@@ -374,8 +378,9 @@ class SpeculativeConfig:
)
if
self
.
method
in
(
"ngram"
,
"[ngram]"
):
# Unified to "ngram" internally
self
.
method
=
"ngram"
if
self
.
method
in
(
"ngram"
,
"ngram_gpu"
):
# Set default values if not provided
if
self
.
prompt_lookup_min
is
None
and
self
.
prompt_lookup_max
is
None
:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
...
...
@@ -832,6 +837,9 @@ class SpeculativeConfig:
def
uses_extract_hidden_states
(
self
)
->
bool
:
return
self
.
method
==
"extract_hidden_states"
def
use_ngram_gpu
(
self
)
->
bool
:
return
self
.
method
==
"ngram_gpu"
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
model
=
(
...
...
vllm/config/vllm.py
View file @
a6be75db
...
...
@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from
.parallel
import
ParallelConfig
from
.profiler
import
ProfilerConfig
from
.scheduler
import
SchedulerConfig
from
.speculative
import
EagleModelTypes
,
SpeculativeConfig
from
.speculative
import
EagleModelTypes
,
NgramGPUTypes
,
SpeculativeConfig
from
.structured_outputs
import
StructuredOutputsConfig
from
.utils
import
SupportsHash
,
config
,
replace
from
.weight_transfer
import
WeightTransferConfig
...
...
@@ -696,11 +696,13 @@ class VllmConfig:
if
self
.
speculative_config
is
not
None
:
if
(
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
and
self
.
speculative_config
.
method
!=
"draft_model"
):
raise
ValueError
(
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding."
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
...
...
@@ -718,6 +720,7 @@ class VllmConfig:
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
not
in
get_args
(
NgramGPUTypes
)
):
logger
.
warning_once
(
"Async scheduling not supported with %s-based "
...
...
vllm/tool_parsers/hermes_tool_parser.py
View file @
a6be75db
...
...
@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
assert
current_tool_call
is
not
None
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
...
...
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
assert
isinstance
(
current_tool_call
,
dict
)
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
else
:
...
...
vllm/v1/spec_decode/ngram_proposer_gpu.py
0 → 100644
View file @
a6be75db
This diff is collapsed.
Click to expand it.
vllm/v1/worker/gpu_input_batch.py
View file @
a6be75db
...
...
@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self
.
req_prompt_embeds
:
dict
[
int
,
torch
.
Tensor
]
=
{}
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_tokens_no_spec
=
self
.
num_tokens_no_spec_cpu_tensor
.
numpy
()
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a6be75db
...
...
@@ -10,7 +10,7 @@ from collections import defaultdict
from
collections.abc
import
Iterable
,
Iterator
,
Sequence
from
contextlib
import
contextmanager
from
copy
import
copy
,
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
functools
import
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
...
...
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
NgramProposerGPU
,
copy_num_valid_draft_tokens
,
update_ngram_gpu_tensors_incremental
,
update_scheduler_for_invalid_drafts
,
)
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.structured_output.utils
import
apply_grammar_bitmask
from
vllm.v1.utils
import
CpuGpuBuffer
,
record_function_or_nullcontext
...
...
@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mi
r
co-batches
# TODO: Support overlapping mic
r
o-batches
# https://github.com/vllm-project/vllm/issues/18019
self
.
broadcast_pp_output
=
(
self
.
parallel_config
.
distributed_executor_backend
==
"external_launcher"
...
...
@@ -493,6 +499,7 @@ class GPUModelRunner(
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
self
.
drafter
:
(
NgramProposer
# noqa: F823
|
NgramProposerGPU
|
SuffixDecodingProposer
|
EagleProposer
|
DraftModelProposer
...
...
@@ -509,6 +516,23 @@ class GPUModelRunner(
device
=
self
.
device
,
runner
=
self
,
)
elif
self
.
speculative_config
.
use_ngram_gpu
():
self
.
drafter
=
NgramProposerGPU
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
num_tokens_no_spec_gpu
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
token_ids_gpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
_ngram_pinned_idx_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
long
,
pin_memory
=
True
)
self
.
_ngram_pinned_val_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
...
...
@@ -564,7 +588,7 @@ class GPUModelRunner(
)
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
# We need to use the encoder length for encoder-decoer
# We need to use the encoder length for encoder-deco
d
er
# because of KV cache for cross-attention.
max_model_len
=
max
(
self
.
max_model_len
,
self
.
max_encoder_len
),
max_num_batched_tokens
=
self
.
max_num_tokens
,
...
...
@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs.
self
.
_draft_token_ids
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
=
None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self
.
_num_valid_draft_tokens
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_cpu
:
torch
.
Tensor
|
None
=
None
self
.
_num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
|
None
=
None
self
.
_num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
|
None
=
None
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
self
.
_num_valid_draft_tokens_cpu
=
torch
.
empty
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
self
.
pin_memory
)
self
.
_num_valid_draft_tokens_event
=
torch
.
cuda
.
Event
()
self
.
_num_valid_draft_tokens_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_draft_token_req_ids
:
list
[
str
]
|
None
=
None
self
.
transfer_event
=
torch
.
Event
()
self
.
sampled_token_ids_pinned_cpu
=
torch
.
empty
(
...
...
@@ -992,6 +1031,13 @@ class GPUModelRunner(
for
req_id
in
unscheduled_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
is_ngram_gpu
=
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
)
if
is_ngram_gpu
:
ngram_gpu_new_reqs
:
list
[
CachedRequestState
]
=
[]
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
# Add new requests to the cached states.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
...
...
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self
.
_init_xdrope_positions
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track new requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
# Update the states of the running/resumed requests.
is_last_rank
=
get_pp_group
().
is_last_rank
req_data
=
scheduler_output
.
scheduled_cached_reqs
scheduled_spec_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req
:
dict
[
str
,
int
]
=
{}
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
for
req_id
,
toks
in
scheduled_spec_tokens
.
items
():
original_num_spec_per_req
[
req_id
]
=
len
(
toks
)
update_scheduler_for_invalid_drafts
(
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens_cpu
,
scheduler_output
,
self
.
input_batch
.
req_id_to_index
,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count
=
self
.
_get_valid_sampled_token_count
()
...
...
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# fi
r
st step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain
# num_computed_tokens in first step and second step does
n
't contain
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
...
...
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens
-=
num_rejected
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
if
is_ngram_gpu
and
num_accepted
>
0
and
req_index
is
not
None
:
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
+=
num_accepted
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
...
...
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
reqs_to_add
.
append
(
req_state
)
# Track resumed requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
ngram_gpu_new_reqs
.
append
(
req_state
)
continue
# Update the persistent batch.
...
...
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu.
self
.
input_batch
.
update_req_spec_token_ids
(
req_state
,
scheduled_spec_tokens
)
# Restore scheduler-side draft count after ngram trimming.
if
original_num_spec_per_req
:
orig
=
original_num_spec_per_req
.
get
(
req_id
,
0
)
if
orig
!=
req_state
.
prev_num_draft_len
:
req_state
.
prev_num_draft_len
=
orig
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
...
...
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
# Incrementally update ngram_gpu tensors after batch is stable
if
is_ngram_gpu
:
update_ngram_gpu_tensors_incremental
(
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
ngram_gpu_new_reqs
,
self
.
device
,
_pinned_idx_buf
=
self
.
_ngram_pinned_idx_buf
,
_pinned_val_buf
=
self
.
_ngram_pinned_val_buf
,
)
def
_update_states_after_model_execute
(
self
,
output_token_ids
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
...
...
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else
:
logger
.
error
(
"RoutedExpertsCapturer not initialized."
)
# If ngram_gpu is used, we need to copy the scheduler_output to avoid
# the modification has influence on the scheduler_output in engine core process.
# The replace is much faster than deepcopy.
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_ngram_gpu
()
):
num_scheduled_tokens_copy
=
scheduler_output
.
num_scheduled_tokens
.
copy
()
spec_decode_tokens_copy
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
copy
()
)
scheduler_output
=
replace
(
scheduler_output
,
num_scheduled_tokens
=
num_scheduled_tokens_copy
,
scheduled_spec_decode_tokens
=
spec_decode_tokens_copy
,
)
if
scheduler_output
.
preempted_req_ids
and
has_kv_transfer_group
():
get_kv_transfer_group
().
handle_preemptions
(
scheduler_output
.
preempted_req_ids
...
...
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
self
.
_draft_token_ids
=
torch
.
zeros
(
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
).
expand
(
len
(
self
.
input_batch
.
req_ids
),
self
.
num_spec_tokens
)
self
.
_copy_draft_token_ids_to_cpu
(
scheduler_output
,
zeros_only
=
True
)
elif
(
spec_config
.
use_ngram_gpu
()
and
not
spec_config
.
disable_padded_drafter_batch
):
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
elif
self
.
valid_sampled_token_count_event
is
not
None
:
assert
spec_decode_common_attn_metadata
is
not
None
next_token_ids
,
valid_sampled_tokens_count
,
_
=
(
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self
.
_draft_token_ids
=
torch
.
zeros
(
...
...
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self
.
input_batch
.
token_ids_cpu
,
slot_mappings
=
slot_mappings
,
)
if
isinstance
(
self
.
drafter
,
NgramProposer
):
assert
isinstance
(
sampled_token_ids
,
list
),
(
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
,
self
.
input_batch
.
num_tokens_no_spec
,
self
.
input_batch
.
token_ids_cpu
,
)
elif
spec_config
.
use_ngram_gpu
():
assert
isinstance
(
self
.
drafter
,
NgramProposerGPU
)
(
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
,
)
=
self
.
drafter
.
update_token_ids_ngram
(
sampled_token_ids
,
self
.
input_batch
,
self
.
token_ids_gpu_tensor
,
self
.
num_tokens_no_spec_gpu
,
self
.
discard_request_mask
.
gpu
,
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
batch_size
=
next_token_ids
.
shape
[
0
]
draft_token_ids
,
num_valid_draft_tokens
=
self
.
drafter
.
propose
(
self
.
num_tokens_no_spec_gpu
[:
batch_size
],
self
.
token_ids_gpu_tensor
[:
batch_size
],
valid_sampled_token_ids_gpu
,
valid_sampled_tokens_count
,
)
# Cache valid draft counts for scheduler-side trimming.
self
.
_num_valid_draft_tokens
=
num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens
(
self
.
_num_valid_draft_tokens_cpu
,
self
.
_num_valid_draft_tokens_copy_stream
,
self
.
_num_valid_draft_tokens_event
,
self
.
_num_valid_draft_tokens
,
self
.
input_batch
.
num_reqs
,
)
elif
spec_config
.
method
==
"suffix"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
SuffixDecodingProposer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment