Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4a5299c9
Unverified
Commit
4a5299c9
authored
Jan 19, 2026
by
Tomas Ruiz
Committed by
GitHub
Jan 19, 2026
Browse files
feat: spec decode with draft models (#24322)
Signed-off-by:
Tomas Ruiz
<
tomas.ruiz.te@gmail.com
>
parent
73f2a81c
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
895 additions
and
113 deletions
+895
-113
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+17
-2
examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py
..._serving/disaggregated_serving/moriio_toy_proxy_server.py
+1
-1
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+306
-9
tests/v1/worker/test_utils.py
tests/v1/worker/test_utils.py
+35
-0
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+10
-9
vllm/benchmarks/lib/ready_checker.py
vllm/benchmarks/lib/ready_checker.py
+6
-0
vllm/config/parallel.py
vllm/config/parallel.py
+4
-0
vllm/config/speculative.py
vllm/config/speculative.py
+33
-6
vllm/config/vllm.py
vllm/config/vllm.py
+23
-13
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-15
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+7
-2
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+2
-2
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+2
-2
vllm/model_executor/model_loader/tensorizer_loader.py
vllm/model_executor/model_loader/tensorizer_loader.py
+4
-3
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+11
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+32
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+2
-0
vllm/v1/spec_decode/draft_model.py
vllm/v1/spec_decode/draft_model.py
+271
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+98
-38
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+31
-10
No files found.
examples/offline_inference/spec_decode.py
View file @
4a5299c9
...
...
@@ -54,7 +54,7 @@ def parse_args():
"--method"
,
type
=
str
,
default
=
"eagle"
,
choices
=
[
"ngram"
,
"eagle"
,
"eagle3"
,
"mtp"
],
choices
=
[
"ngram"
,
"eagle"
,
"eagle3"
,
"mtp"
,
"draft_model"
],
)
parser
.
add_argument
(
"--num-spec-tokens"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--prompt-lookup-max"
,
type
=
int
,
default
=
5
)
...
...
@@ -70,7 +70,11 @@ def parse_args():
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
256
)
parser
.
add_argument
(
"--model-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--eagle-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--draft-model"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--custom-mm-prompts"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
return
parser
.
parse_args
()
...
...
@@ -111,6 +115,7 @@ def main(args):
"method"
:
args
.
method
,
"model"
:
eagle_dir
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"disable_padded_drafter_batch"
:
args
.
disable_padded_drafter_batch
,
}
elif
args
.
method
==
"ngram"
:
speculative_config
=
{
...
...
@@ -119,6 +124,15 @@ def main(args):
"prompt_lookup_max"
:
args
.
prompt_lookup_max
,
"prompt_lookup_min"
:
args
.
prompt_lookup_min
,
}
elif
args
.
method
==
"draft_model"
:
assert
args
.
draft_model
is
not
None
and
args
.
draft_model
!=
""
speculative_config
=
{
"method"
:
args
.
method
,
"model"
:
args
.
draft_model
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"enforce_eager"
:
args
.
enforce_eager
,
"max_model_len"
:
args
.
max_model_len
,
}
elif
args
.
method
==
"mtp"
:
speculative_config
=
{
"method"
:
"mtp"
,
...
...
@@ -133,12 +147,13 @@ def main(args):
tensor_parallel_size
=
args
.
tp
,
enable_chunked_prefill
=
args
.
enable_chunked_prefill
,
enforce_eager
=
args
.
enforce_eager
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
speculative_config
=
speculative_config
,
disable_log_stats
=
False
,
max_model_len
=
args
.
max_model_len
,
limit_mm_per_prompt
=
{
"image"
:
5
},
disable_chunked_mm_input
=
True
,
max_num_seqs
=
args
.
max_num_seqs
,
)
sampling_params
=
SamplingParams
(
temperature
=
args
.
temp
,
max_tokens
=
args
.
output_len
)
...
...
examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py
View file @
4a5299c9
...
...
@@ -4,13 +4,13 @@ import asyncio
import
copy
import
logging
import
os
import
re
import
socket
import
threading
import
uuid
import
aiohttp
import
msgpack
import
regex
as
re
import
zmq
from
quart
import
Quart
,
make_response
,
request
...
...
tests/v1/e2e/test_spec_decode.py
View file @
4a5299c9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
typing
import
Any
import
pytest
...
...
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.benchmarks.datasets
import
InstructCoderDataset
from
vllm.config.vllm
import
VllmConfig
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.spec_decode.draft_model
import
(
create_vllm_config_for_draft_model
,
merge_toks_kernel
,
)
MTP_SIMILARITY_RATE
=
0.8
def
_skip_if_insufficient_gpus_for_tp
(
tp_size
:
int
):
"""Skip test if available GPUs < tp_size on ROCm."""
if
current_platform
.
is_rocm
():
available_gpus
=
torch
.
cuda
.
device_count
()
if
available_gpus
<
tp_size
:
pytest
.
skip
(
f
"Test requires
{
tp_size
}
GPUs, but only
{
available_gpus
}
available"
)
available_gpus
=
torch
.
cuda
.
device_count
()
if
available_gpus
<
tp_size
:
pytest
.
skip
(
f
"Test requires
{
tp_size
}
GPUs, but only
{
available_gpus
}
available"
)
def
get_test_prompts
(
mm_enabled
:
bool
):
Messages
=
list
[
dict
[
str
,
Any
]]
def
get_test_prompts
(
mm_enabled
:
bool
,
quiet
:
bool
=
False
,
num_prompts
:
int
=
100
)
->
list
[
Messages
]:
prompt_types
=
[
"repeat"
,
"sentence"
]
if
mm_enabled
:
prompt_types
.
append
(
"mm"
)
num_prompts
=
100
prompts
=
[]
random
.
seed
(
0
)
random_prompt_type_choices
=
random
.
choices
(
prompt_types
,
k
=
num_prompts
)
print
(
f
"Prompt types:
{
random_prompt_type_choices
}
"
)
if
not
quiet
:
print
(
f
"Prompt types:
{
random_prompt_type_choices
}
"
)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
...
...
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
return
prompts
def
get_instruct_coder_messages
(
n
:
int
)
->
list
[
Messages
]:
dataset
=
InstructCoderDataset
(
dataset_path
=
"likaixin/InstructCoder"
,
dataset_split
=
"train"
)
prompts
:
Iterable
[
str
]
=
dataset
.
sample_prompts
(
n
=
n
)
return
[[{
"role"
:
"user"
,
"content"
:
prompt
}]
for
prompt
in
prompts
]
@
pytest
.
fixture
def
sampling_config
():
return
greedy_sampling
()
def
greedy_sampling
()
->
SamplingParams
:
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
,
ignore_eos
=
False
)
def
stochastic_sampling
()
->
SamplingParams
:
return
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
10
,
ignore_eos
=
False
)
@
pytest
.
fixture
def
model_name
():
return
"meta-llama/Llama-3.1-8B-Instruct"
...
...
@@ -583,3 +614,269 @@ def test_mtp_correctness(
del
spec_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
@
dataclass
class
ArgsTest
:
target_model
:
str
draft_model
:
str
sampling_config
:
SamplingParams
num_speculative_tokens
:
int
expected_acceptance_rate
:
float
expected_acceptance_len
:
float
# Defaults
target_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
max_model_len
:
int
=
1024
gpu_memory_utilization
:
float
=
0.5
dataset
:
str
=
"test_prompts"
num_prompts
:
int
=
100
cases
=
[
# Same model for draft and target, greedy sampling.
ArgsTest
(
target_model
=
"Qwen/Qwen3-0.6B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
greedy_sampling
(),
num_speculative_tokens
=
3
,
# K
expected_acceptance_len
=
3
+
1
,
# K + 1
expected_acceptance_rate
=
1.0
,
),
# Smaller draft model, stochastic sampling.
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
stochastic_sampling
(),
num_speculative_tokens
=
3
,
expected_acceptance_len
=
2.8
+
1
,
expected_acceptance_rate
=
0.9
,
),
]
@
pytest
.
mark
.
parametrize
(
"args"
,
cases
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
assert_draft_model_correctness
(
args
,
enforce_eager
)
def
test_draft_model_realistic_example
():
args
=
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
dataset
=
"likaixin/InstructCoder"
,
num_speculative_tokens
=
3
,
sampling_config
=
greedy_sampling
(),
# values below are not derived, but just prevent a regression
expected_acceptance_len
=
2.8
,
expected_acceptance_rate
=
0.55
,
)
assert_draft_model_correctness
(
args
,
enforce_eager
=
False
)
@
pytest
.
mark
.
parametrize
(
"models"
,
[
# target_model, draft_model
(
"Qwen/Qwen3-1.7B-FP8"
,
"Qwen/Qwen3-0.6B"
),
# target quantized
(
"Qwen/Qwen3-1.7B"
,
"Qwen/Qwen3-0.6B-FP8"
),
# draft quantized
],
ids
=
[
"target_quantized"
,
"draft_quantized"
],
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_draft_model_quantization
(
models
:
tuple
[
str
,
str
],
enforce_eager
:
bool
):
tgt_model
,
draft_model
=
models
sd_case
=
ArgsTest
(
target_model
=
tgt_model
,
draft_model
=
draft_model
,
**
some_high_acceptance_metrics
(),
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
)
def
test_draft_model_tensor_parallelism
():
"""Ensure spec decode works when running with TP > 1."""
_skip_if_insufficient_gpus_for_tp
(
2
)
sd_case
=
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
target_tensor_parallel_size
=
2
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_tensor_parallel_size
=
2
,
**
some_high_acceptance_metrics
(),
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
=
False
)
def
test_draft_model_engine_args_tensor_parallelism
():
"""Ensure the vllm_config for the draft model is created correctly,
and independently of the target model (quantization, TP, etc.)"""
_skip_if_insufficient_gpus_for_tp
(
2
)
engine_args
=
EngineArgs
(
model
=
"Qwen/Qwen3-1.7B-FP8"
,
# <<< tgt quantized
tensor_parallel_size
=
2
,
speculative_config
=
{
"model"
:
"Qwen/Qwen3-0.6B"
,
# <<< draft not quantized
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
3
,
"draft_tensor_parallel_size"
:
1
,
# <<< valid arg name
},
)
tgt_vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
()
assert
tgt_vllm_config
.
parallel_config
.
tensor_parallel_size
==
2
assert
tgt_vllm_config
.
quant_config
.
get_name
()
==
"fp8"
draft_vllm_config
:
VllmConfig
=
create_vllm_config_for_draft_model
(
tgt_vllm_config
)
assert
draft_vllm_config
.
parallel_config
.
tensor_parallel_size
==
1
assert
draft_vllm_config
.
quant_config
is
None
def
test_draft_model_engine_args_rejects_invalid_tp_argname
():
"""The user should pass "draft_tensor_parallel_size" rather than
"tensor_parallel_size". We enforce this with validation."""
engine_args
=
EngineArgs
(
model
=
"Qwen/Qwen3-1.7B"
,
tensor_parallel_size
=
1
,
speculative_config
=
{
"model"
:
"Qwen/Qwen3-0.6B"
,
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
3
,
"tensor_parallel_size"
:
1
,
# <<< invalid arg name
},
)
with
pytest
.
raises
(
ValueError
):
engine_args
.
create_engine_config
()
def
assert_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts
:
list
[
Messages
]
=
get_messages
(
dataset
=
args
.
dataset
,
n
=
args
.
num_prompts
)
spec_llm
=
LLM
(
model
=
args
.
target_model
,
speculative_config
=
{
"model"
:
args
.
draft_model
,
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
args
.
num_speculative_tokens
,
"max_model_len"
:
args
.
max_model_len
,
"enforce_eager"
:
enforce_eager
,
"draft_tensor_parallel_size"
:
args
.
draft_tensor_parallel_size
,
"max_num_seqs"
:
100
,
# limit cudagraph capture runtime
},
max_model_len
=
args
.
max_model_len
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
tensor_parallel_size
=
args
.
target_tensor_parallel_size
,
enforce_eager
=
enforce_eager
,
disable_log_stats
=
False
,
# enables get_metrics()
)
# we don't check the outputs, only check the metrics
spec_llm
.
chat
(
test_prompts
,
args
.
sampling_config
)
metrics
=
spec_llm
.
get_metrics
()
acceptance_rate
:
float
=
compute_acceptance_rate
(
metrics
)
acceptance_len
:
float
=
compute_acceptance_len
(
metrics
)
del
spec_llm
# CLEANUP
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
print
(
f
"spec-decode: target=
{
args
.
target_model
}
, draft=
{
args
.
draft_model
}
, "
f
"temperature=
{
args
.
sampling_config
.
temperature
:.
2
f
}
, "
f
"acceptance_rate=
{
acceptance_rate
:.
2
f
}
, "
f
"acceptance_len=
{
acceptance_len
:.
2
f
}
, "
)
assert
acceptance_rate
>=
args
.
expected_acceptance_rate
assert
acceptance_len
>=
args
.
expected_acceptance_len
def
get_messages
(
dataset
:
str
,
n
:
int
)
->
list
[
Messages
]:
if
dataset
==
"test_prompts"
:
return
get_test_prompts
(
mm_enabled
=
False
,
quiet
=
True
,
num_prompts
=
n
)
elif
dataset
==
"likaixin/InstructCoder"
:
return
get_instruct_coder_messages
(
n
=
n
)
else
:
raise
NotImplementedError
(
f
"Dataset '
{
dataset
}
' not implemented"
)
def
some_high_acceptance_metrics
()
->
dict
:
return
{
"sampling_config"
:
greedy_sampling
(),
"num_speculative_tokens"
:
3
,
"expected_acceptance_len"
:
2.90
+
1
,
"expected_acceptance_rate"
:
0.90
,
}
def
test_merge_toks_kernel
():
device
=
"cuda"
merged_len
=
5
+
2
# len(target_toks) = 5, batch_size = 2
merged
=
torch
.
full
((
merged_len
,),
-
100
,
device
=
device
)
# -100 is arbitrary
is_rejected_tok
=
torch
.
full
((
merged_len
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
3
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
4
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
5
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
1
,
2
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
([
False
]
*
merged_len
,
device
=
device
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
test_merge_toks_kernel_with_rejected_tokens
():
device
=
"cuda"
merged_size
=
9
+
2
# len(target_toks) = 9, batch_size = 2
merged
=
torch
.
full
((
merged_size
,),
-
100
,
device
=
device
)
is_rejected_tok
=
torch
.
full
((
merged_size
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
13
,
14
,
15
,
0
,
1
,
22
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
6
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
7
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
9
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
-
1
,
-
1
,
-
1
,
0
,
1
,
2
,
-
1
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
(
[
False
,
False
,
False
,
False
,
True
,
True
,
True
,
False
,
False
,
False
,
True
],
device
=
device
,
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
if
n_draft_toks
==
0
:
return
float
(
"nan"
)
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
return
n_accepted_toks
/
n_draft_toks
def
compute_acceptance_len
(
metrics
:
list
[
Metric
])
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
if
n_drafts
==
0
:
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
tests/v1/worker/test_utils.py
View file @
4a5299c9
...
...
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
assert
runner_kv_caches
[
0
]
is
kv_cache
[
"model.layers.20.attn"
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
"model.layers.28.attn"
]
def
test_bind_kv_cache_draft_model
(
default_vllm_config
):
from
vllm.attention.layer
import
Attention
layer_names
=
[
"model.layers.0.attn"
,
"model.layers.1.attn"
,
"draft_model.layers.0.attn"
,
"draft_model.layers.1.attn"
,
]
ctx
=
{
layer_name
:
Attention
(
32
,
128
,
0.1
,
prefix
=
layer_name
)
for
layer_name
in
layer_names
}
kv_cache
=
{
layer_name
:
torch
.
zeros
((
1
,))
for
layer_name
in
layer_names
}
runner_kv_caches
:
list
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
kv_cache
,
ctx
,
runner_kv_caches
)
assert
ctx
[
"model.layers.0.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"model.layers.0.attn"
]
assert
ctx
[
"model.layers.1.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"model.layers.1.attn"
]
assert
(
ctx
[
"draft_model.layers.0.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"draft_model.layers.0.attn"
]
)
assert
(
ctx
[
"draft_model.layers.1.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"draft_model.layers.1.attn"
]
)
# caches are ordered by layer_index, interleaving target and draft model
assert
runner_kv_caches
[
0
]
is
kv_cache
[
"model.layers.0.attn"
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
"draft_model.layers.0.attn"
]
assert
runner_kv_caches
[
2
]
is
kv_cache
[
"model.layers.1.attn"
]
assert
runner_kv_caches
[
3
]
is
kv_cache
[
"draft_model.layers.1.attn"
]
vllm/benchmarks/datasets.py
View file @
4a5299c9
...
...
@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
**
kwargs
,
)
->
list
:
)
->
list
[
SampleRequest
]
:
output_len
=
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
sampled_requests
=
[]
for
i
,
item
in
enumerate
(
self
.
data
):
if
len
(
sampled_requests
)
>=
num_requests
:
break
prompt
=
(
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output "
"the code, do not include any explanation."
)
for
i
,
prompt
in
enumerate
(
self
.
sample_prompts
(
n
=
num_requests
)):
# apply template
if
not
skip_chat_template
:
prompt
=
tokenizer
.
apply_chat_template
(
...
...
@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
)
return
sampled_requests
def
sample_prompts
(
self
,
n
:
int
)
->
Iterator
[
str
]:
for
item
in
self
.
data
.
take
(
n
):
prompt
=
(
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output "
"the code, do not include any explanation."
)
yield
prompt
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
...
...
vllm/benchmarks/lib/ready_checker.py
View file @
4a5299c9
...
...
@@ -8,8 +8,12 @@ import time
import
aiohttp
from
tqdm.asyncio
import
tqdm
from
vllm.logger
import
init_logger
from
.endpoint_request_func
import
RequestFunc
,
RequestFuncInput
,
RequestFuncOutput
logger
=
init_logger
(
__name__
)
async
def
wait_for_endpoint
(
request_func
:
RequestFunc
,
...
...
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
if
output
.
success
:
pbar
.
close
()
return
output
else
:
logger
.
warning
(
"Endpoint is not ready. Error='%s'"
,
output
.
error
)
except
aiohttp
.
ClientConnectorError
:
pass
...
...
vllm/config/parallel.py
View file @
4a5299c9
...
...
@@ -3,6 +3,7 @@
import
os
from
collections.abc
import
Callable
from
dataclasses
import
replace
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
import
torch
...
...
@@ -709,3 +710,6 @@ class ParallelConfig:
)
return
self
def
replace
(
self
,
**
kwargs
)
->
Self
:
return
replace
(
self
,
**
kwargs
)
vllm/config/speculative.py
View file @
4a5299c9
...
...
@@ -77,6 +77,9 @@ class SpeculativeConfig:
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
tensor_parallel_size
:
int
|
None
=
None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization
:
me_quant
.
QuantizationMethods
|
None
=
None
...
...
@@ -397,13 +400,11 @@ class SpeculativeConfig:
"one layer. Might need some code changes "
"to support multiple layers."
)
elif
self
.
method
==
"draft_model"
:
pass
else
:
self
.
method
=
"draft_model"
raise
NotImplementedError
(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or mtp."
f
"Unsupported speculative method: '
{
self
.
method
}
'"
)
# Replace hf_config for EAGLE draft_model
...
...
@@ -631,6 +632,12 @@ class SpeculativeConfig:
@
model_validator
(
mode
=
"after"
)
def
_verify_args
(
self
)
->
Self
:
if
self
.
tensor_parallel_size
is
not
None
:
raise
ValueError
(
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if
self
.
num_speculative_tokens
is
None
:
raise
ValueError
(
"num_speculative_tokens must be provided with "
...
...
@@ -669,12 +676,32 @@ class SpeculativeConfig:
f
"Eagle3 is only supported for
{
eagle3_target_supported
}
models. "
# noqa: E501
f
"Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
)
self
.
verify_equal_vocab_size_if_draft_model
()
return
self
def
verify_equal_vocab_size_if_draft_model
(
self
):
if
(
self
.
method
==
"draft_model"
and
self
.
target_model_config
is
not
None
and
self
.
draft_model_config
is
not
None
):
target_vocab_size
=
self
.
target_model_config
.
get_vocab_size
()
draft_vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
if
target_vocab_size
!=
draft_vocab_size
:
raise
ValueError
(
f
"Target and draft model should have the same vocabulary size. "
f
"Target model vocab_size=
{
target_vocab_size
}
. "
f
"Draft model vocab_size=
{
draft_vocab_size
}
. "
f
"Using models with different tokenizers can cause out-of-bounds "
f
"errors during speculative decoding."
)
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
model
=
None
if
method
in
(
"ngram"
,
"suffix"
)
else
self
.
draft_model_config
.
model
...
...
vllm/config/vllm.py
View file @
4a5299c9
...
...
@@ -1214,10 +1214,19 @@ class VllmConfig:
compilation_config
=
self
.
compilation_config
computed_compile_ranges_split_points
=
[]
# The upper bound of the compile ranges is the max_num_batched_tokens
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
if
max_num_batched_tokens
is
not
None
:
computed_compile_ranges_split_points
.
append
(
max_num_batched_tokens
)
# The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended
# by 1 for each sequence.
compile_range_end
=
self
.
scheduler_config
.
max_num_batched_tokens
if
compile_range_end
is
not
None
:
do_extend
:
bool
=
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
uses_draft_model
()
)
if
do_extend
:
compile_range_end
+=
self
.
scheduler_config
.
max_num_seqs
computed_compile_ranges_split_points
.
append
(
compile_range_end
)
# Add the compile ranges for flashinfer
if
compilation_config
.
pass_config
.
fuse_allreduce_rms
:
...
...
@@ -1228,10 +1237,7 @@ class VllmConfig:
self
.
model_config
.
get_hidden_size
()
*
self
.
model_config
.
dtype
.
itemsize
)
if
(
max_num_batched_tokens
is
not
None
and
max_token_num
<
max_num_batched_tokens
):
if
compile_range_end
is
not
None
and
max_token_num
<
compile_range_end
:
computed_compile_ranges_split_points
.
append
(
max_token_num
)
else
:
logger
.
debug
(
...
...
@@ -1243,11 +1249,7 @@ class VllmConfig:
for
x
in
compilation_config
.
compile_ranges_split_points
:
assert
isinstance
(
x
,
int
)
assert
x
>
0
,
f
"Invalid compile range split point:
{
x
}
"
if
(
max_num_batched_tokens
is
not
None
and
x
<
max_num_batched_tokens
and
x
>
1
):
if
compile_range_end
is
not
None
and
x
<
compile_range_end
and
x
>
1
:
computed_compile_ranges_split_points
.
append
(
x
)
compilation_config
.
compile_ranges_split_points
=
sorted
(
computed_compile_ranges_split_points
...
...
@@ -1316,6 +1318,14 @@ class VllmConfig:
path
=
self
.
compilation_config
.
debug_dump_path
/
append_path
return
path
def
replace
(
self
,
**
kwargs
):
"""
Replace attributes of the config, and 'recompute' the config.
dataclass.replace() calls __init__() and __post_init__(), source:
https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
"""
return
replace
(
self
,
**
kwargs
)
def
__str__
(
self
):
return
(
f
"model=
{
self
.
model_config
.
model
!
r
}
, "
...
...
vllm/engine/arg_utils.py
View file @
4a5299c9
...
...
@@ -1776,21 +1776,6 @@ class EngineArgs:
):
_raise_unsupported_error
(
feature_name
=
"Concurrent Partial Prefill"
)
# N-gram, Medusa, and Eagle are supported for speculative decoding.
if
self
.
speculative_config
is
not
None
:
# speculative_config could still be a dict at this point
if
isinstance
(
self
.
speculative_config
,
dict
):
method
=
self
.
speculative_config
.
get
(
"method"
,
None
)
else
:
method
=
self
.
speculative_config
.
method
if
method
==
"draft_model"
:
raise
NotImplementedError
(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or mtp."
)
if
self
.
pipeline_parallel_size
>
1
:
supports_pp
=
getattr
(
self
.
distributed_executor_backend
,
"supports_pp"
,
False
...
...
vllm/model_executor/model_loader/__init__.py
View file @
4a5299c9
...
...
@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
def
get_model
(
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
|
None
=
None
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
loader
=
get_model_loader
(
vllm_config
.
load_config
)
if
model_config
is
None
:
model_config
=
vllm_config
.
model_config
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
__all__
=
[
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
4a5299c9
...
...
@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
raise
NotImplementedError
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
device_config
=
vllm_config
.
device_config
...
...
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
log_model_inspection
(
model
)
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
4a5299c9
...
...
@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
local_model_path
=
self
.
_prepare_weights
(
model_config
)
...
...
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
load_weights
(
model
,
model_config
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
...
...
vllm/model_executor/model_loader/tensorizer_loader.py
View file @
4a5299c9
...
...
@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
def
_load_model_serialized_cpu
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
...
...
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
...
...
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
model
.
load_weights
(
self
.
_get_weights_iterator
())
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
parallel_config
=
vllm_config
.
parallel_config
self
.
_verify_config
(
model_config
,
parallel_config
)
...
...
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
)
self
.
load_weights
(
model
,
model_config
)
return
model
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
@
staticmethod
def
save_model
(
...
...
vllm/v1/attention/backend.py
View file @
4a5299c9
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
...
...
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache
:
torch
.
Tensor
|
None
=
None
def
batch_size
(
self
)
->
int
:
return
self
.
seq_lens
.
shape
[
0
]
def
naive_query_lens
(
self
)
->
torch
.
Tensor
:
"""Naive because it assumes that query ends where the next query starts."""
return
self
.
query_start_loc
[
1
:]
-
self
.
query_start_loc
[:
-
1
]
def
replace
(
self
,
**
kwargs
)
->
"CommonAttentionMetadata"
:
return
replace
(
self
,
**
kwargs
)
@
property
@
deprecated
(
"""
...
...
vllm/v1/attention/backends/utils.py
View file @
4a5299c9
...
...
@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
)
dcp_local_seq_lens
=
base
+
remainder
return
dcp_local_seq_lens
.
squeeze
(
1
)
def
extend_all_queries_by_1
(
common_attn_metadata
:
CommonAttentionMetadata
,
arange
:
torch
.
Tensor
,
new_slot_mapping
:
torch
.
Tensor
,
)
->
CommonAttentionMetadata
:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by 1.
Also all seq lens are increased by 1.
This is useful e.g. in speculative decoding with draft models, where we
extend each sequence by 1 token.
The slot mapping is computed externally, as it requires more information.
"""
cad
=
common_attn_metadata
# query start loc must be increased by [+0, +1, +2, ..., +batch_size]
new_query_start_loc
=
cad
.
query_start_loc
+
arange
[:
len
(
cad
.
query_start_loc
)]
new_query_start_loc_cpu
=
cad
.
query_start_loc_cpu
+
torch
.
arange
(
len
(
cad
.
query_start_loc_cpu
),
dtype
=
torch
.
int32
)
new_cad
=
cad
.
replace
(
query_start_loc
=
new_query_start_loc
,
query_start_loc_cpu
=
new_query_start_loc_cpu
,
seq_lens
=
cad
.
seq_lens
+
1
,
# each request is extended by 1 token -> batch_size tokens are added
num_actual_tokens
=
cad
.
num_actual_tokens
+
cad
.
batch_size
(),
# All query lens increase by 1, so max query len increases by 1
max_query_len
=
cad
.
max_query_len
+
1
,
max_seq_len
=
cad
.
max_seq_len
+
1
,
slot_mapping
=
new_slot_mapping
,
)
return
new_cad
vllm/v1/core/sched/scheduler.py
View file @
4a5299c9
...
...
@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
if
speculative_config
.
use_eagle
():
self
.
use_eagle
=
True
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
if
speculative_config
.
uses_draft_model
():
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
...
...
vllm/v1/spec_decode/draft_model.py
0 → 100644
View file @
4a5299c9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
extend_all_queries_by_1
,
)
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
SpecDecodeBaseProposer
logger
=
init_logger
(
__name__
)
class
DraftModelProposer
(
SpecDecodeBaseProposer
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
device
=
device
,
pass_hidden_states_to_model
=
False
,
runner
=
runner
,
)
self
.
_raise_if_multimodal
()
self
.
_raise_if_mrope
()
self
.
_raise_if_padded_drafter_batch_disabled
()
self
.
_raise_if_vocab_size_mismatch
()
self
.
_raise_if_draft_tp_mismatch
()
def
_block_size
(
self
)
->
int
:
builder
=
self
.
_get_attention_metadata_builder
()
return
builder
.
kv_cache_spec
.
block_size
def
_raise_if_multimodal
(
self
):
if
self
.
supports_mm_inputs
:
raise
NotImplementedError
(
"Speculative Decoding with draft models "
"does not support multimodal models yet"
)
def
_raise_if_mrope
(
self
):
if
self
.
draft_model_config
.
uses_mrope
:
raise
NotImplementedError
(
"Speculative Decoding with draft models does not support M-RoPE yet"
)
def
_raise_if_padded_drafter_batch_disabled
(
self
):
if
self
.
vllm_config
.
speculative_config
.
disable_padded_drafter_batch
:
raise
NotImplementedError
(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
" in the speculative_config."
)
def
_raise_if_vocab_size_mismatch
(
self
):
self
.
vllm_config
.
speculative_config
.
verify_equal_vocab_size_if_draft_model
()
def
_raise_if_draft_tp_mismatch
(
self
):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
# the draft model with TP = 1, then the different TP ranks collide.
# Specifically when all ranks compile the draft model on rank 0
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg
:
SpeculativeConfig
=
self
.
vllm_config
.
speculative_config
tgt_tp
=
spec_cfg
.
target_parallel_config
.
tensor_parallel_size
draft_tp
=
spec_cfg
.
draft_parallel_config
.
tensor_parallel_size
if
draft_tp
!=
tgt_tp
:
raise
ValueError
(
f
"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
f
"must be the same. Got
{
draft_tp
}
and
{
tgt_tp
}
. "
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)
def
set_inputs_first_pass
(
self
,
target_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
batch_size
=
cad
.
batch_size
()
grid
=
(
batch_size
,)
start_locs
=
cad
.
query_start_loc
[:
-
1
]
end_locs
=
cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
end_locs
-=
num_rejected_tokens_gpu
num_tokens
=
target_token_ids
.
shape
[
0
]
+
batch_size
is_rejected_tok
=
torch
.
empty
(
(
num_tokens
,),
device
=
self
.
input_ids
.
device
,
dtype
=
torch
.
bool
)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
target_token_ids
,
next_toks_ptr
=
next_token_ids
,
query_start_locs_ptr
=
start_locs
,
query_end_locs_ptr
=
end_locs
,
out_ptr_merged_toks
=
self
.
input_ids
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
target_token_ids
.
shape
[
0
],
# passing a negative rejected_tok_fill value will raise an error
# when the value is used to index into embeddings.
# Therefore, we pass a valid integer, e.g. 0.
rejected_tok_fill
=
0
,
)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
target_positions
,
next_toks_ptr
=
target_positions
[
end_locs
]
+
1
,
query_start_locs_ptr
=
start_locs
,
query_end_locs_ptr
=
end_locs
,
out_ptr_merged_toks
=
self
.
positions
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
target_positions
.
shape
[
0
],
rejected_tok_fill
=
0
,
)
# recompute slot mapping
new_slot_mapping
=
compute_new_slot_mapping
(
cad
=
cad
,
new_positions
=
self
.
positions
[:
num_tokens
],
is_rejected_token_mask
=
is_rejected_tok
,
block_size
=
self
.
_block_size
(),
max_model_len
=
self
.
max_model_len
,
)
# update common_attn_metadata
new_cad
:
CommonAttentionMetadata
=
extend_all_queries_by_1
(
cad
,
arange
=
self
.
arange
,
new_slot_mapping
=
new_slot_mapping
,
)
new_last_token_indices
=
new_cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
new_last_token_indices
-=
num_rejected_tokens_gpu
return
num_tokens
,
new_last_token_indices
,
new_cad
def
load_model
(
self
,
target_model
:
Any
)
->
None
:
"""Takes target_model to satisfy the type checker."""
# This must be computed before loading the draft model
# because that mutates the forward_context of the vllm_config
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
)
from
vllm.compilation.backends
import
set_model_tag
draft_vllm_config
:
VllmConfig
=
create_vllm_config_for_draft_model
(
target_model_vllm_config
=
self
.
vllm_config
)
logger
.
info
(
"Starting to load draft model %s. TP=%d, rank=%d"
,
draft_vllm_config
.
model_config
.
model
,
draft_vllm_config
.
parallel_config
.
tensor_parallel_size
,
draft_vllm_config
.
parallel_config
.
rank
,
)
with
set_model_tag
(
"draft_model"
):
self
.
model
=
get_model
(
vllm_config
=
draft_vllm_config
,
prefix
=
"draft_model"
)
# This must be computed after loading the draft model
# because that mutates the forward_context of the vllm_config
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
def
create_vllm_config_for_draft_model
(
target_model_vllm_config
:
VllmConfig
,
)
->
VllmConfig
:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the draft model.
The vllm_config is useful when loading the draft model with get_model().
"""
old
=
target_model_vllm_config
new_parallel_config
=
old
.
speculative_config
.
draft_parallel_config
.
replace
(
rank
=
old
.
parallel_config
.
rank
)
new
:
VllmConfig
=
old
.
replace
(
quant_config
=
None
,
# quant_config is recomputed in __init__()
model_config
=
old
.
speculative_config
.
draft_model_config
,
parallel_config
=
new_parallel_config
,
)
return
new
def
compute_new_slot_mapping
(
cad
:
CommonAttentionMetadata
,
new_positions
:
torch
.
Tensor
,
is_rejected_token_mask
:
torch
.
Tensor
,
block_size
:
int
,
max_model_len
:
int
,
):
batch_size
,
n_blocks_per_req
=
cad
.
block_table_tensor
.
shape
req_indices
=
torch
.
arange
(
batch_size
,
device
=
cad
.
query_start_loc
.
device
)
req_indices
=
torch
.
repeat_interleave
(
req_indices
,
cad
.
naive_query_lens
()
+
1
,
output_size
=
len
(
new_positions
)
)
# Clamp the positions to prevent an out-of-bounds error when indexing
# into block_table_tensor.
clamped_positions
=
torch
.
clamp
(
new_positions
,
max
=
max_model_len
-
1
)
block_table_indices
=
(
req_indices
*
n_blocks_per_req
+
clamped_positions
//
block_size
)
block_nums
=
cad
.
block_table_tensor
.
view
(
-
1
)[
block_table_indices
]
block_offsets
=
clamped_positions
%
block_size
new_slot_mapping
=
block_nums
*
block_size
+
block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len
=
new_positions
>=
max_model_len
new_slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping
.
masked_fill_
(
is_rejected_token_mask
,
PADDING_SLOT_ID
)
return
new_slot_mapping
@
triton
.
jit
def
merge_toks_kernel
(
target_toks_ptr
,
next_toks_ptr
,
query_start_locs_ptr
,
query_end_locs_ptr
,
out_ptr_merged_toks
,
out_ptr_is_rejected_tok
,
target_toks_size
,
rejected_tok_fill
,
):
"""
Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor
called `out_ptr_merged_toks`. Rejected tokens are those after the
`query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the
rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask
of the rejected tokens in `out_ptr_is_rejected_tok`.
"""
pid
=
tl
.
program_id
(
0
)
start_loc
=
tl
.
load
(
query_start_locs_ptr
+
pid
)
is_last_program
=
pid
==
tl
.
num_programs
(
0
)
-
1
if
is_last_program
:
next_start_loc
=
target_toks_size
.
to
(
tl
.
int32
)
else
:
next_start_loc
=
tl
.
load
(
query_start_locs_ptr
+
pid
+
1
).
to
(
tl
.
int32
)
end_loc
=
tl
.
load
(
query_end_locs_ptr
+
pid
)
new_val
=
tl
.
load
(
next_toks_ptr
+
pid
)
for
i
in
range
(
start_loc
,
next_start_loc
+
1
):
if
i
<=
end_loc
:
# copy existing tokens
old_val
=
tl
.
load
(
target_toks_ptr
+
i
)
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
old_val
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
False
)
elif
i
==
end_loc
+
1
:
# copy bonus token
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
new_val
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
False
)
else
:
# fill rejected tokens
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
rejected_tok_fill
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
True
)
vllm/v1/spec_decode/eagle.py
View file @
4a5299c9
...
...
@@ -53,11 +53,12 @@ logger = init_logger(__name__)
PADDING_SLOT_ID
=
-
1
class
Eagl
eProposer
:
class
SpecDecodeBas
eProposer
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
pass_hidden_states_to_model
:
bool
,
runner
=
None
,
):
self
.
vllm_config
=
vllm_config
...
...
@@ -65,6 +66,7 @@ class EagleProposer:
assert
self
.
speculative_config
is
not
None
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
method
=
self
.
speculative_config
.
method
self
.
pass_hidden_states_to_model
=
pass_hidden_states_to_model
self
.
runner
=
runner
self
.
device
=
device
...
...
@@ -72,7 +74,11 @@ class EagleProposer:
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
...
...
@@ -143,7 +149,6 @@ class EagleProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
self
.
arange
=
torch
.
arange
(
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
...
...
@@ -245,11 +250,7 @@ class EagleProposer:
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
if
last_token_indices
is
None
:
last_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
batch_size
=
common_attn_metadata
.
batch_size
()
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
...
...
@@ -257,12 +258,17 @@ class EagleProposer:
target_hidden_states
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
num_tokens
,
last_token_indices
,
common_attn_metadata
=
(
self
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
last_token_indices
=
last_token_indices
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
)
assert
self
.
runner
is
not
None
...
...
@@ -311,9 +317,10 @@ class EagleProposer:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
# copy inputs to buffer for cudagraph
self
.
_set_positions
(
num_tokens
,
target_positions
)
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
self
.
pass_hidden_states_to_model
:
# target_hidden_states and self.hidden_states can have different
# hidden dims. E.g. large target model and small draft model.
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
...
...
@@ -330,6 +337,14 @@ class EagleProposer:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -337,17 +352,13 @@ class EagleProposer:
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
_get_positions
(
num_input_tokens
),
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
==
"mtp"
:
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
...
...
@@ -357,9 +368,9 @@ class EagleProposer:
return
draft_token_ids
.
view
(
-
1
,
1
)
if
self
.
uses_mrope
:
positions
=
target_
positions
[:,
last_token_indices
]
positions
=
self
.
positions
[:,
last_token_indices
]
else
:
positions
=
target_
positions
[
last_token_indices
]
positions
=
self
.
positions
[
last_token_indices
]
if
self
.
method
in
(
"deepseek_mtp"
,
"ernie_mtp"
,
...
...
@@ -527,6 +538,14 @@ class EagleProposer:
inputs_embeds
=
None
# Run the model.
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
input_batch_size
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
input_batch_size
]
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -534,17 +553,13 @@ class EagleProposer:
num_tokens_across_dp
=
batch_size_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
_get_positions
(
input_batch_size
),
hidden_states
=
self
.
hidden_states
[:
input_batch_size
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
==
"mtp"
:
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
])
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
...
...
@@ -554,6 +569,34 @@ class EagleProposer:
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
def
set_inputs_first_pass
(
self
,
target_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
if
last_token_indices
is
None
:
last_token_indices
=
cad
.
query_start_loc
[
1
:]
-
1
num_tokens
=
target_token_ids
.
shape
[
0
]
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
# copy inputs to buffer for cudagraph
self
.
_set_positions
(
num_tokens
,
target_positions
)
return
num_tokens
,
last_token_indices
,
cad
def
model_returns_tuple
(
self
)
->
bool
:
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
)
def
prepare_next_token_ids_cpu
(
self
,
sampled_token_ids
:
list
[
list
[
int
]],
...
...
@@ -1214,12 +1257,14 @@ class EagleProposer:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
self
.
model
(
kwargs
=
dict
(
input_ids
=
input_ids
,
positions
=
self
.
_get_positions
(
num_input_tokens
),
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
pass_hidden_states_to_model
:
kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
self
.
model
(
**
kwargs
)
def
_get_attention_metadata_builder
(
self
)
->
AttentionMetadataBuilder
:
"""Find and return the attention metadata builders for EAGLE layers.
...
...
@@ -1264,8 +1309,8 @@ class EagleProposer:
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Validate that all
eagle
layers belong to the same KVCacheGroup.
Need this assumption to ensure all
eagle
layers can use the
Validate that all
drafting
layers belong to the same KVCacheGroup.
Need this assumption to ensure all
drafting
layers can use the
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
"""
...
...
@@ -1283,7 +1328,7 @@ class EagleProposer:
)
)
==
1
),
"All
eagle
layers should belong to the same kv cache group"
),
"All
drafting
layers should belong to the same kv cache group"
def
_pad_batch_across_dp
(
self
,
...
...
@@ -1308,6 +1353,21 @@ class EagleProposer:
return
num_tokens_dp_padded
,
num_toks_across_dp
class
EagleProposer
(
SpecDecodeBaseProposer
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
,
):
super
().
__init__
(
vllm_config
,
device
,
pass_hidden_states_to_model
=
True
,
runner
=
runner
,
)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4a5299c9
...
...
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -432,10 +433,20 @@ class GPUModelRunner(
# layers in the draft model.
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
self
.
drafter
:
(
NgramProposer
|
SuffixDecodingProposer
|
EagleProposer
|
MedusaProposer
NgramProposer
|
SuffixDecodingProposer
|
EagleProposer
|
DraftModelProposer
|
MedusaProposer
)
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
uses_draft_model
():
self
.
drafter
=
DraftModelProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
,
runner
=
self
,
)
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
...
...
@@ -3443,10 +3454,13 @@ class GPUModelRunner(
spec_decode_common_attn_metadata
.
max_seq_len
+
self
.
num_spec_tokens
<=
self
.
effective_drafter_max_model_len
)
if
spec_config
.
use_eagle
()
and
not
spec_config
.
disable_padded_drafter_batch
:
# EAGLE speculative decoding can use the GPU sampled tokens
use_gpu_toks
=
(
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
)
and
not
spec_config
.
disable_padded_drafter_batch
if
use_gpu_toks
:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
...
...
@@ -3679,8 +3693,8 @@ class GPUModelRunner(
target_hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
)
elif
spec_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
:
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
if
spec_config
.
disable_padded_drafter_batch
:
# When padded-batch is disabled, the sampled_token_ids should be
...
...
@@ -4475,8 +4489,12 @@ class GPUModelRunner(
else
:
hidden_states
=
outputs
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
if
self
.
speculative_config
and
(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
self
.
speculative_config
is
not
None
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
...
...
@@ -5652,8 +5670,11 @@ class GPUModelRunner(
kv_cache_config
,
kernel_block_sizes
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
if
self
.
speculative_config
and
(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
# validate all draft model layers belong to the same kv cache
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment