Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4a5299c9
Unverified
Commit
4a5299c9
authored
Jan 19, 2026
by
Tomas Ruiz
Committed by
GitHub
Jan 19, 2026
Browse files
feat: spec decode with draft models (#24322)
Signed-off-by:
Tomas Ruiz
<
tomas.ruiz.te@gmail.com
>
parent
73f2a81c
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
895 additions
and
113 deletions
+895
-113
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+17
-2
examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py
..._serving/disaggregated_serving/moriio_toy_proxy_server.py
+1
-1
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+306
-9
tests/v1/worker/test_utils.py
tests/v1/worker/test_utils.py
+35
-0
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+10
-9
vllm/benchmarks/lib/ready_checker.py
vllm/benchmarks/lib/ready_checker.py
+6
-0
vllm/config/parallel.py
vllm/config/parallel.py
+4
-0
vllm/config/speculative.py
vllm/config/speculative.py
+33
-6
vllm/config/vllm.py
vllm/config/vllm.py
+23
-13
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-15
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+7
-2
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+2
-2
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+2
-2
vllm/model_executor/model_loader/tensorizer_loader.py
vllm/model_executor/model_loader/tensorizer_loader.py
+4
-3
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+11
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+32
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+2
-0
vllm/v1/spec_decode/draft_model.py
vllm/v1/spec_decode/draft_model.py
+271
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+98
-38
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+31
-10
No files found.
examples/offline_inference/spec_decode.py
View file @
4a5299c9
...
@@ -54,7 +54,7 @@ def parse_args():
...
@@ -54,7 +54,7 @@ def parse_args():
"--method"
,
"--method"
,
type
=
str
,
type
=
str
,
default
=
"eagle"
,
default
=
"eagle"
,
choices
=
[
"ngram"
,
"eagle"
,
"eagle3"
,
"mtp"
],
choices
=
[
"ngram"
,
"eagle"
,
"eagle3"
,
"mtp"
,
"draft_model"
],
)
)
parser
.
add_argument
(
"--num-spec-tokens"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--num-spec-tokens"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--prompt-lookup-max"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--prompt-lookup-max"
,
type
=
int
,
default
=
5
)
...
@@ -70,7 +70,11 @@ def parse_args():
...
@@ -70,7 +70,11 @@ def parse_args():
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
256
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
256
)
parser
.
add_argument
(
"--model-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--model-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--eagle-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--eagle-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--draft-model"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--custom-mm-prompts"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--custom-mm-prompts"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -111,6 +115,7 @@ def main(args):
...
@@ -111,6 +115,7 @@ def main(args):
"method"
:
args
.
method
,
"method"
:
args
.
method
,
"model"
:
eagle_dir
,
"model"
:
eagle_dir
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"disable_padded_drafter_batch"
:
args
.
disable_padded_drafter_batch
,
}
}
elif
args
.
method
==
"ngram"
:
elif
args
.
method
==
"ngram"
:
speculative_config
=
{
speculative_config
=
{
...
@@ -119,6 +124,15 @@ def main(args):
...
@@ -119,6 +124,15 @@ def main(args):
"prompt_lookup_max"
:
args
.
prompt_lookup_max
,
"prompt_lookup_max"
:
args
.
prompt_lookup_max
,
"prompt_lookup_min"
:
args
.
prompt_lookup_min
,
"prompt_lookup_min"
:
args
.
prompt_lookup_min
,
}
}
elif
args
.
method
==
"draft_model"
:
assert
args
.
draft_model
is
not
None
and
args
.
draft_model
!=
""
speculative_config
=
{
"method"
:
args
.
method
,
"model"
:
args
.
draft_model
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"enforce_eager"
:
args
.
enforce_eager
,
"max_model_len"
:
args
.
max_model_len
,
}
elif
args
.
method
==
"mtp"
:
elif
args
.
method
==
"mtp"
:
speculative_config
=
{
speculative_config
=
{
"method"
:
"mtp"
,
"method"
:
"mtp"
,
...
@@ -133,12 +147,13 @@ def main(args):
...
@@ -133,12 +147,13 @@ def main(args):
tensor_parallel_size
=
args
.
tp
,
tensor_parallel_size
=
args
.
tp
,
enable_chunked_prefill
=
args
.
enable_chunked_prefill
,
enable_chunked_prefill
=
args
.
enable_chunked_prefill
,
enforce_eager
=
args
.
enforce_eager
,
enforce_eager
=
args
.
enforce_eager
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
max_model_len
=
args
.
max_model_len
,
max_model_len
=
args
.
max_model_len
,
limit_mm_per_prompt
=
{
"image"
:
5
},
limit_mm_per_prompt
=
{
"image"
:
5
},
disable_chunked_mm_input
=
True
,
disable_chunked_mm_input
=
True
,
max_num_seqs
=
args
.
max_num_seqs
,
)
)
sampling_params
=
SamplingParams
(
temperature
=
args
.
temp
,
max_tokens
=
args
.
output_len
)
sampling_params
=
SamplingParams
(
temperature
=
args
.
temp
,
max_tokens
=
args
.
output_len
)
...
...
examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py
View file @
4a5299c9
...
@@ -4,13 +4,13 @@ import asyncio
...
@@ -4,13 +4,13 @@ import asyncio
import
copy
import
copy
import
logging
import
logging
import
os
import
os
import
re
import
socket
import
socket
import
threading
import
threading
import
uuid
import
uuid
import
aiohttp
import
aiohttp
import
msgpack
import
msgpack
import
regex
as
re
import
zmq
import
zmq
from
quart
import
Quart
,
make_response
,
request
from
quart
import
Quart
,
make_response
,
request
...
...
tests/v1/e2e/test_spec_decode.py
View file @
4a5299c9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
random
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
typing
import
Any
from
typing
import
Any
import
pytest
import
pytest
...
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
...
@@ -10,32 +12,45 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.benchmarks.datasets
import
InstructCoderDataset
from
vllm.config.vllm
import
VllmConfig
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.spec_decode.draft_model
import
(
create_vllm_config_for_draft_model
,
merge_toks_kernel
,
)
MTP_SIMILARITY_RATE
=
0.8
MTP_SIMILARITY_RATE
=
0.8
def
_skip_if_insufficient_gpus_for_tp
(
tp_size
:
int
):
def
_skip_if_insufficient_gpus_for_tp
(
tp_size
:
int
):
"""Skip test if available GPUs < tp_size on ROCm."""
"""Skip test if available GPUs < tp_size on ROCm."""
if
current_platform
.
is_rocm
():
available_gpus
=
torch
.
cuda
.
device_count
()
available_gpus
=
torch
.
cuda
.
device_count
()
if
available_gpus
<
tp_size
:
if
available_gpus
<
tp_size
:
pytest
.
skip
(
pytest
.
skip
(
f
"Test requires
{
tp_size
}
GPUs, but only
{
available_gpus
}
available"
f
"Test requires
{
tp_size
}
GPUs, but only
{
available_gpus
}
available"
)
)
def
get_test_prompts
(
mm_enabled
:
bool
):
Messages
=
list
[
dict
[
str
,
Any
]]
def
get_test_prompts
(
mm_enabled
:
bool
,
quiet
:
bool
=
False
,
num_prompts
:
int
=
100
)
->
list
[
Messages
]:
prompt_types
=
[
"repeat"
,
"sentence"
]
prompt_types
=
[
"repeat"
,
"sentence"
]
if
mm_enabled
:
if
mm_enabled
:
prompt_types
.
append
(
"mm"
)
prompt_types
.
append
(
"mm"
)
num_prompts
=
100
prompts
=
[]
prompts
=
[]
random
.
seed
(
0
)
random
.
seed
(
0
)
random_prompt_type_choices
=
random
.
choices
(
prompt_types
,
k
=
num_prompts
)
random_prompt_type_choices
=
random
.
choices
(
prompt_types
,
k
=
num_prompts
)
print
(
f
"Prompt types:
{
random_prompt_type_choices
}
"
)
if
not
quiet
:
print
(
f
"Prompt types:
{
random_prompt_type_choices
}
"
)
# Generate a mixed batch of prompts, some of which can be easily
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
# predicted by n-gram matching and some which likely cannot.
...
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
...
@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
return
prompts
return
prompts
def
get_instruct_coder_messages
(
n
:
int
)
->
list
[
Messages
]:
dataset
=
InstructCoderDataset
(
dataset_path
=
"likaixin/InstructCoder"
,
dataset_split
=
"train"
)
prompts
:
Iterable
[
str
]
=
dataset
.
sample_prompts
(
n
=
n
)
return
[[{
"role"
:
"user"
,
"content"
:
prompt
}]
for
prompt
in
prompts
]
@
pytest
.
fixture
@
pytest
.
fixture
def
sampling_config
():
def
sampling_config
():
return
greedy_sampling
()
def
greedy_sampling
()
->
SamplingParams
:
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
,
ignore_eos
=
False
)
return
SamplingParams
(
temperature
=
0
,
max_tokens
=
10
,
ignore_eos
=
False
)
def
stochastic_sampling
()
->
SamplingParams
:
return
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
10
,
ignore_eos
=
False
)
@
pytest
.
fixture
@
pytest
.
fixture
def
model_name
():
def
model_name
():
return
"meta-llama/Llama-3.1-8B-Instruct"
return
"meta-llama/Llama-3.1-8B-Instruct"
...
@@ -583,3 +614,269 @@ def test_mtp_correctness(
...
@@ -583,3 +614,269 @@ def test_mtp_correctness(
del
spec_llm
del
spec_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
@
dataclass
class
ArgsTest
:
target_model
:
str
draft_model
:
str
sampling_config
:
SamplingParams
num_speculative_tokens
:
int
expected_acceptance_rate
:
float
expected_acceptance_len
:
float
# Defaults
target_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
max_model_len
:
int
=
1024
gpu_memory_utilization
:
float
=
0.5
dataset
:
str
=
"test_prompts"
num_prompts
:
int
=
100
cases
=
[
# Same model for draft and target, greedy sampling.
ArgsTest
(
target_model
=
"Qwen/Qwen3-0.6B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
greedy_sampling
(),
num_speculative_tokens
=
3
,
# K
expected_acceptance_len
=
3
+
1
,
# K + 1
expected_acceptance_rate
=
1.0
,
),
# Smaller draft model, stochastic sampling.
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
sampling_config
=
stochastic_sampling
(),
num_speculative_tokens
=
3
,
expected_acceptance_len
=
2.8
+
1
,
expected_acceptance_rate
=
0.9
,
),
]
@
pytest
.
mark
.
parametrize
(
"args"
,
cases
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
assert_draft_model_correctness
(
args
,
enforce_eager
)
def
test_draft_model_realistic_example
():
args
=
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
dataset
=
"likaixin/InstructCoder"
,
num_speculative_tokens
=
3
,
sampling_config
=
greedy_sampling
(),
# values below are not derived, but just prevent a regression
expected_acceptance_len
=
2.8
,
expected_acceptance_rate
=
0.55
,
)
assert_draft_model_correctness
(
args
,
enforce_eager
=
False
)
@
pytest
.
mark
.
parametrize
(
"models"
,
[
# target_model, draft_model
(
"Qwen/Qwen3-1.7B-FP8"
,
"Qwen/Qwen3-0.6B"
),
# target quantized
(
"Qwen/Qwen3-1.7B"
,
"Qwen/Qwen3-0.6B-FP8"
),
# draft quantized
],
ids
=
[
"target_quantized"
,
"draft_quantized"
],
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_draft_model_quantization
(
models
:
tuple
[
str
,
str
],
enforce_eager
:
bool
):
tgt_model
,
draft_model
=
models
sd_case
=
ArgsTest
(
target_model
=
tgt_model
,
draft_model
=
draft_model
,
**
some_high_acceptance_metrics
(),
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
)
def
test_draft_model_tensor_parallelism
():
"""Ensure spec decode works when running with TP > 1."""
_skip_if_insufficient_gpus_for_tp
(
2
)
sd_case
=
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
target_tensor_parallel_size
=
2
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_tensor_parallel_size
=
2
,
**
some_high_acceptance_metrics
(),
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
=
False
)
def
test_draft_model_engine_args_tensor_parallelism
():
"""Ensure the vllm_config for the draft model is created correctly,
and independently of the target model (quantization, TP, etc.)"""
_skip_if_insufficient_gpus_for_tp
(
2
)
engine_args
=
EngineArgs
(
model
=
"Qwen/Qwen3-1.7B-FP8"
,
# <<< tgt quantized
tensor_parallel_size
=
2
,
speculative_config
=
{
"model"
:
"Qwen/Qwen3-0.6B"
,
# <<< draft not quantized
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
3
,
"draft_tensor_parallel_size"
:
1
,
# <<< valid arg name
},
)
tgt_vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
()
assert
tgt_vllm_config
.
parallel_config
.
tensor_parallel_size
==
2
assert
tgt_vllm_config
.
quant_config
.
get_name
()
==
"fp8"
draft_vllm_config
:
VllmConfig
=
create_vllm_config_for_draft_model
(
tgt_vllm_config
)
assert
draft_vllm_config
.
parallel_config
.
tensor_parallel_size
==
1
assert
draft_vllm_config
.
quant_config
is
None
def
test_draft_model_engine_args_rejects_invalid_tp_argname
():
"""The user should pass "draft_tensor_parallel_size" rather than
"tensor_parallel_size". We enforce this with validation."""
engine_args
=
EngineArgs
(
model
=
"Qwen/Qwen3-1.7B"
,
tensor_parallel_size
=
1
,
speculative_config
=
{
"model"
:
"Qwen/Qwen3-0.6B"
,
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
3
,
"tensor_parallel_size"
:
1
,
# <<< invalid arg name
},
)
with
pytest
.
raises
(
ValueError
):
engine_args
.
create_engine_config
()
def
assert_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts
:
list
[
Messages
]
=
get_messages
(
dataset
=
args
.
dataset
,
n
=
args
.
num_prompts
)
spec_llm
=
LLM
(
model
=
args
.
target_model
,
speculative_config
=
{
"model"
:
args
.
draft_model
,
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
args
.
num_speculative_tokens
,
"max_model_len"
:
args
.
max_model_len
,
"enforce_eager"
:
enforce_eager
,
"draft_tensor_parallel_size"
:
args
.
draft_tensor_parallel_size
,
"max_num_seqs"
:
100
,
# limit cudagraph capture runtime
},
max_model_len
=
args
.
max_model_len
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
tensor_parallel_size
=
args
.
target_tensor_parallel_size
,
enforce_eager
=
enforce_eager
,
disable_log_stats
=
False
,
# enables get_metrics()
)
# we don't check the outputs, only check the metrics
spec_llm
.
chat
(
test_prompts
,
args
.
sampling_config
)
metrics
=
spec_llm
.
get_metrics
()
acceptance_rate
:
float
=
compute_acceptance_rate
(
metrics
)
acceptance_len
:
float
=
compute_acceptance_len
(
metrics
)
del
spec_llm
# CLEANUP
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
print
(
f
"spec-decode: target=
{
args
.
target_model
}
, draft=
{
args
.
draft_model
}
, "
f
"temperature=
{
args
.
sampling_config
.
temperature
:.
2
f
}
, "
f
"acceptance_rate=
{
acceptance_rate
:.
2
f
}
, "
f
"acceptance_len=
{
acceptance_len
:.
2
f
}
, "
)
assert
acceptance_rate
>=
args
.
expected_acceptance_rate
assert
acceptance_len
>=
args
.
expected_acceptance_len
def
get_messages
(
dataset
:
str
,
n
:
int
)
->
list
[
Messages
]:
if
dataset
==
"test_prompts"
:
return
get_test_prompts
(
mm_enabled
=
False
,
quiet
=
True
,
num_prompts
=
n
)
elif
dataset
==
"likaixin/InstructCoder"
:
return
get_instruct_coder_messages
(
n
=
n
)
else
:
raise
NotImplementedError
(
f
"Dataset '
{
dataset
}
' not implemented"
)
def
some_high_acceptance_metrics
()
->
dict
:
return
{
"sampling_config"
:
greedy_sampling
(),
"num_speculative_tokens"
:
3
,
"expected_acceptance_len"
:
2.90
+
1
,
"expected_acceptance_rate"
:
0.90
,
}
def
test_merge_toks_kernel
():
device
=
"cuda"
merged_len
=
5
+
2
# len(target_toks) = 5, batch_size = 2
merged
=
torch
.
full
((
merged_len
,),
-
100
,
device
=
device
)
# -100 is arbitrary
is_rejected_tok
=
torch
.
full
((
merged_len
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
3
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
4
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
5
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
1
,
2
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
([
False
]
*
merged_len
,
device
=
device
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
test_merge_toks_kernel_with_rejected_tokens
():
device
=
"cuda"
merged_size
=
9
+
2
# len(target_toks) = 9, batch_size = 2
merged
=
torch
.
full
((
merged_size
,),
-
100
,
device
=
device
)
is_rejected_tok
=
torch
.
full
((
merged_size
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
13
,
14
,
15
,
0
,
1
,
22
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
6
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
7
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
9
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
-
1
,
-
1
,
-
1
,
0
,
1
,
2
,
-
1
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
(
[
False
,
False
,
False
,
False
,
True
,
True
,
True
,
False
,
False
,
False
,
True
],
device
=
device
,
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
if
n_draft_toks
==
0
:
return
float
(
"nan"
)
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
return
n_accepted_toks
/
n_draft_toks
def
compute_acceptance_len
(
metrics
:
list
[
Metric
])
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
if
n_drafts
==
0
:
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
tests/v1/worker/test_utils.py
View file @
4a5299c9
...
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
...
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
assert
runner_kv_caches
[
0
]
is
kv_cache
[
"model.layers.20.attn"
]
assert
runner_kv_caches
[
0
]
is
kv_cache
[
"model.layers.20.attn"
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
"model.layers.28.attn"
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
"model.layers.28.attn"
]
def
test_bind_kv_cache_draft_model
(
default_vllm_config
):
from
vllm.attention.layer
import
Attention
layer_names
=
[
"model.layers.0.attn"
,
"model.layers.1.attn"
,
"draft_model.layers.0.attn"
,
"draft_model.layers.1.attn"
,
]
ctx
=
{
layer_name
:
Attention
(
32
,
128
,
0.1
,
prefix
=
layer_name
)
for
layer_name
in
layer_names
}
kv_cache
=
{
layer_name
:
torch
.
zeros
((
1
,))
for
layer_name
in
layer_names
}
runner_kv_caches
:
list
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
kv_cache
,
ctx
,
runner_kv_caches
)
assert
ctx
[
"model.layers.0.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"model.layers.0.attn"
]
assert
ctx
[
"model.layers.1.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"model.layers.1.attn"
]
assert
(
ctx
[
"draft_model.layers.0.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"draft_model.layers.0.attn"
]
)
assert
(
ctx
[
"draft_model.layers.1.attn"
].
kv_cache
[
0
]
is
kv_cache
[
"draft_model.layers.1.attn"
]
)
# caches are ordered by layer_index, interleaving target and draft model
assert
runner_kv_caches
[
0
]
is
kv_cache
[
"model.layers.0.attn"
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
"draft_model.layers.0.attn"
]
assert
runner_kv_caches
[
2
]
is
kv_cache
[
"model.layers.1.attn"
]
assert
runner_kv_caches
[
3
]
is
kv_cache
[
"draft_model.layers.1.attn"
]
vllm/benchmarks/datasets.py
View file @
4a5299c9
...
@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
...
@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
request_id_prefix
:
str
=
""
,
request_id_prefix
:
str
=
""
,
no_oversample
:
bool
=
False
,
no_oversample
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
list
:
)
->
list
[
SampleRequest
]
:
output_len
=
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
output_len
=
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
sampled_requests
=
[]
sampled_requests
=
[]
for
i
,
item
in
enumerate
(
self
.
data
):
for
i
,
prompt
in
enumerate
(
self
.
sample_prompts
(
n
=
num_requests
)):
if
len
(
sampled_requests
)
>=
num_requests
:
break
prompt
=
(
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output "
"the code, do not include any explanation."
)
# apply template
# apply template
if
not
skip_chat_template
:
if
not
skip_chat_template
:
prompt
=
tokenizer
.
apply_chat_template
(
prompt
=
tokenizer
.
apply_chat_template
(
...
@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
...
@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
)
)
return
sampled_requests
return
sampled_requests
def
sample_prompts
(
self
,
n
:
int
)
->
Iterator
[
str
]:
for
item
in
self
.
data
.
take
(
n
):
prompt
=
(
f
"
{
item
[
'input'
]
}
\n\n
{
item
[
'instruction'
]
}
Just output "
"the code, do not include any explanation."
)
yield
prompt
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# MT-Bench Dataset Implementation
...
...
vllm/benchmarks/lib/ready_checker.py
View file @
4a5299c9
...
@@ -8,8 +8,12 @@ import time
...
@@ -8,8 +8,12 @@ import time
import
aiohttp
import
aiohttp
from
tqdm.asyncio
import
tqdm
from
tqdm.asyncio
import
tqdm
from
vllm.logger
import
init_logger
from
.endpoint_request_func
import
RequestFunc
,
RequestFuncInput
,
RequestFuncOutput
from
.endpoint_request_func
import
RequestFunc
,
RequestFuncInput
,
RequestFuncOutput
logger
=
init_logger
(
__name__
)
async
def
wait_for_endpoint
(
async
def
wait_for_endpoint
(
request_func
:
RequestFunc
,
request_func
:
RequestFunc
,
...
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
...
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
if
output
.
success
:
if
output
.
success
:
pbar
.
close
()
pbar
.
close
()
return
output
return
output
else
:
logger
.
warning
(
"Endpoint is not ready. Error='%s'"
,
output
.
error
)
except
aiohttp
.
ClientConnectorError
:
except
aiohttp
.
ClientConnectorError
:
pass
pass
...
...
vllm/config/parallel.py
View file @
4a5299c9
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
os
import
os
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
replace
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
import
torch
import
torch
...
@@ -709,3 +710,6 @@ class ParallelConfig:
...
@@ -709,3 +710,6 @@ class ParallelConfig:
)
)
return
self
return
self
def
replace
(
self
,
**
kwargs
)
->
Self
:
return
replace
(
self
,
**
kwargs
)
vllm/config/speculative.py
View file @
4a5299c9
...
@@ -77,6 +77,9 @@ class SpeculativeConfig:
...
@@ -77,6 +77,9 @@ class SpeculativeConfig:
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""The degree of the tensor parallelism for the draft model. Can only be 1
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
or the same as the target model's tensor parallel size."""
tensor_parallel_size
:
int
|
None
=
None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
# Draft model configuration
quantization
:
me_quant
.
QuantizationMethods
|
None
=
None
quantization
:
me_quant
.
QuantizationMethods
|
None
=
None
...
@@ -397,13 +400,11 @@ class SpeculativeConfig:
...
@@ -397,13 +400,11 @@ class SpeculativeConfig:
"one layer. Might need some code changes "
"one layer. Might need some code changes "
"to support multiple layers."
"to support multiple layers."
)
)
elif
self
.
method
==
"draft_model"
:
pass
else
:
else
:
self
.
method
=
"draft_model"
raise
NotImplementedError
(
raise
NotImplementedError
(
"Speculative decoding with draft model is not "
f
"Unsupported speculative method: '
{
self
.
method
}
'"
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or mtp."
)
)
# Replace hf_config for EAGLE draft_model
# Replace hf_config for EAGLE draft_model
...
@@ -631,6 +632,12 @@ class SpeculativeConfig:
...
@@ -631,6 +632,12 @@ class SpeculativeConfig:
@
model_validator
(
mode
=
"after"
)
@
model_validator
(
mode
=
"after"
)
def
_verify_args
(
self
)
->
Self
:
def
_verify_args
(
self
)
->
Self
:
if
self
.
tensor_parallel_size
is
not
None
:
raise
ValueError
(
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if
self
.
num_speculative_tokens
is
None
:
if
self
.
num_speculative_tokens
is
None
:
raise
ValueError
(
raise
ValueError
(
"num_speculative_tokens must be provided with "
"num_speculative_tokens must be provided with "
...
@@ -669,12 +676,32 @@ class SpeculativeConfig:
...
@@ -669,12 +676,32 @@ class SpeculativeConfig:
f
"Eagle3 is only supported for
{
eagle3_target_supported
}
models. "
# noqa: E501
f
"Eagle3 is only supported for
{
eagle3_target_supported
}
models. "
# noqa: E501
f
"Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
f
"Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
)
)
self
.
verify_equal_vocab_size_if_draft_model
()
return
self
return
self
def
verify_equal_vocab_size_if_draft_model
(
self
):
if
(
self
.
method
==
"draft_model"
and
self
.
target_model_config
is
not
None
and
self
.
draft_model_config
is
not
None
):
target_vocab_size
=
self
.
target_model_config
.
get_vocab_size
()
draft_vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
if
target_vocab_size
!=
draft_vocab_size
:
raise
ValueError
(
f
"Target and draft model should have the same vocabulary size. "
f
"Target model vocab_size=
{
target_vocab_size
}
. "
f
"Draft model vocab_size=
{
draft_vocab_size
}
. "
f
"Using models with different tokenizers can cause out-of-bounds "
f
"errors during speculative decoding."
)
def
use_eagle
(
self
)
->
bool
:
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
method
=
self
.
method
model
=
None
if
method
in
(
"ngram"
,
"suffix"
)
else
self
.
draft_model_config
.
model
model
=
None
if
method
in
(
"ngram"
,
"suffix"
)
else
self
.
draft_model_config
.
model
...
...
vllm/config/vllm.py
View file @
4a5299c9
...
@@ -1214,10 +1214,19 @@ class VllmConfig:
...
@@ -1214,10 +1214,19 @@ class VllmConfig:
compilation_config
=
self
.
compilation_config
compilation_config
=
self
.
compilation_config
computed_compile_ranges_split_points
=
[]
computed_compile_ranges_split_points
=
[]
# The upper bound of the compile ranges is the max_num_batched_tokens
# The upper bound of the compile ranges is the max_num_batched_tokens.
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
# For speculative decoding with draft model, the compile range must be extended
if
max_num_batched_tokens
is
not
None
:
# by 1 for each sequence.
computed_compile_ranges_split_points
.
append
(
max_num_batched_tokens
)
compile_range_end
=
self
.
scheduler_config
.
max_num_batched_tokens
if
compile_range_end
is
not
None
:
do_extend
:
bool
=
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
uses_draft_model
()
)
if
do_extend
:
compile_range_end
+=
self
.
scheduler_config
.
max_num_seqs
computed_compile_ranges_split_points
.
append
(
compile_range_end
)
# Add the compile ranges for flashinfer
# Add the compile ranges for flashinfer
if
compilation_config
.
pass_config
.
fuse_allreduce_rms
:
if
compilation_config
.
pass_config
.
fuse_allreduce_rms
:
...
@@ -1228,10 +1237,7 @@ class VllmConfig:
...
@@ -1228,10 +1237,7 @@ class VllmConfig:
self
.
model_config
.
get_hidden_size
()
self
.
model_config
.
get_hidden_size
()
*
self
.
model_config
.
dtype
.
itemsize
*
self
.
model_config
.
dtype
.
itemsize
)
)
if
(
if
compile_range_end
is
not
None
and
max_token_num
<
compile_range_end
:
max_num_batched_tokens
is
not
None
and
max_token_num
<
max_num_batched_tokens
):
computed_compile_ranges_split_points
.
append
(
max_token_num
)
computed_compile_ranges_split_points
.
append
(
max_token_num
)
else
:
else
:
logger
.
debug
(
logger
.
debug
(
...
@@ -1243,11 +1249,7 @@ class VllmConfig:
...
@@ -1243,11 +1249,7 @@ class VllmConfig:
for
x
in
compilation_config
.
compile_ranges_split_points
:
for
x
in
compilation_config
.
compile_ranges_split_points
:
assert
isinstance
(
x
,
int
)
assert
isinstance
(
x
,
int
)
assert
x
>
0
,
f
"Invalid compile range split point:
{
x
}
"
assert
x
>
0
,
f
"Invalid compile range split point:
{
x
}
"
if
(
if
compile_range_end
is
not
None
and
x
<
compile_range_end
and
x
>
1
:
max_num_batched_tokens
is
not
None
and
x
<
max_num_batched_tokens
and
x
>
1
):
computed_compile_ranges_split_points
.
append
(
x
)
computed_compile_ranges_split_points
.
append
(
x
)
compilation_config
.
compile_ranges_split_points
=
sorted
(
compilation_config
.
compile_ranges_split_points
=
sorted
(
computed_compile_ranges_split_points
computed_compile_ranges_split_points
...
@@ -1316,6 +1318,14 @@ class VllmConfig:
...
@@ -1316,6 +1318,14 @@ class VllmConfig:
path
=
self
.
compilation_config
.
debug_dump_path
/
append_path
path
=
self
.
compilation_config
.
debug_dump_path
/
append_path
return
path
return
path
def replace(self, **kwargs):
    """
    Return a copy of this config with the given attributes overridden.

    The copy is produced by dataclasses.replace(), which runs __init__() and
    __post_init__() again, so the config is 'recomputed' for the new values.
    Reference:
    https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
    """
    updated_config = replace(self, **kwargs)
    return updated_config
def
__str__
(
self
):
def
__str__
(
self
):
return
(
return
(
f
"model=
{
self
.
model_config
.
model
!
r
}
, "
f
"model=
{
self
.
model_config
.
model
!
r
}
, "
...
...
vllm/engine/arg_utils.py
View file @
4a5299c9
...
@@ -1776,21 +1776,6 @@ class EngineArgs:
...
@@ -1776,21 +1776,6 @@ class EngineArgs:
):
):
_raise_unsupported_error
(
feature_name
=
"Concurrent Partial Prefill"
)
_raise_unsupported_error
(
feature_name
=
"Concurrent Partial Prefill"
)
# N-gram, Medusa, and Eagle are supported for speculative decoding.
if
self
.
speculative_config
is
not
None
:
# speculative_config could still be a dict at this point
if
isinstance
(
self
.
speculative_config
,
dict
):
method
=
self
.
speculative_config
.
get
(
"method"
,
None
)
else
:
method
=
self
.
speculative_config
.
method
if
method
==
"draft_model"
:
raise
NotImplementedError
(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or mtp."
)
if
self
.
pipeline_parallel_size
>
1
:
if
self
.
pipeline_parallel_size
>
1
:
supports_pp
=
getattr
(
supports_pp
=
getattr
(
self
.
distributed_executor_backend
,
"supports_pp"
,
False
self
.
distributed_executor_backend
,
"supports_pp"
,
False
...
...
vllm/model_executor/model_loader/__init__.py
View file @
4a5299c9
...
@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
...
@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
def
get_model
(
def
get_model
(
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
|
None
=
None
*
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
loader
=
get_model_loader
(
vllm_config
.
load_config
)
loader
=
get_model_loader
(
vllm_config
.
load_config
)
if
model_config
is
None
:
if
model_config
is
None
:
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
__all__
=
[
__all__
=
[
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
4a5299c9
...
@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
...
@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
raise
NotImplementedError
raise
NotImplementedError
def
load_model
(
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
"""Load a model with the given configurations."""
device_config
=
vllm_config
.
device_config
device_config
=
vllm_config
.
device_config
...
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
...
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
with
target_device
:
model
=
initialize_model
(
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
)
log_model_inspection
(
model
)
log_model_inspection
(
model
)
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
4a5299c9
...
@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
)
)
def
load_model
(
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
device_config
=
vllm_config
.
device_config
local_model_path
=
self
.
_prepare_weights
(
model_config
)
local_model_path
=
self
.
_prepare_weights
(
model_config
)
...
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
target_device
=
torch
.
device
(
device_config
.
device
)
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
load_weights
(
model
,
model_config
)
self
.
load_weights
(
model
,
model_config
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
...
...
vllm/model_executor/model_loader/tensorizer_loader.py
View file @
4a5299c9
...
@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
def
_load_model_serialized_cpu
(
def
_load_model_serialized_cpu
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
"""Load a serialized model with tensorizer to the CPU.
...
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
=
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
return
model
.
eval
()
...
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
model
.
load_weights
(
self
.
_get_weights_iterator
())
model
.
load_weights
(
self
.
_get_weights_iterator
())
def
load_model
(
def
load_model
(
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
self
,
vllm_config
:
VllmConfig
,
model_config
:
ModelConfig
,
prefix
:
str
=
""
)
->
nn
.
Module
:
)
->
nn
.
Module
:
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
self
.
_verify_config
(
model_config
,
parallel_config
)
self
.
_verify_config
(
model_config
,
parallel_config
)
...
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
)
)
self
.
load_weights
(
model
,
model_config
)
self
.
load_weights
(
model
,
model_config
)
return
model
return
model
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
@
staticmethod
@
staticmethod
def
save_model
(
def
save_model
(
...
...
vllm/v1/attention/backend.py
View file @
4a5299c9
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
...
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
...
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache
:
torch
.
Tensor
|
None
=
None
_num_computed_tokens_cache
:
torch
.
Tensor
|
None
=
None
def batch_size(self) -> int:
    """Return the batch size, i.e. the length of the first dim of ``seq_lens``."""
    leading_dim = self.seq_lens.shape[0]
    return leading_dim
def naive_query_lens(self) -> torch.Tensor:
    """
    Return the query lengths as differences of consecutive ``query_start_loc``
    entries.

    Naive because it assumes that query ends where the next query starts.
    """
    boundaries = self.query_start_loc
    next_starts = boundaries[1:]
    own_starts = boundaries[:-1]
    return next_starts - own_starts
def replace(self, **kwargs) -> "CommonAttentionMetadata":
    """Return a copy of this metadata with the given fields overridden.

    Thin wrapper around ``dataclasses.replace``; fields not named in
    ``kwargs`` are carried over from ``self``.
    """
    updated = replace(self, **kwargs)
    return updated
@
property
@
property
@
deprecated
(
@
deprecated
(
"""
"""
...
...
vllm/v1/attention/backends/utils.py
View file @
4a5299c9
...
@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
...
@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
)
)
dcp_local_seq_lens
=
base
+
remainder
dcp_local_seq_lens
=
base
+
remainder
return
dcp_local_seq_lens
.
squeeze
(
1
)
return
dcp_local_seq_lens
.
squeeze
(
1
)
def extend_all_queries_by_1(
    common_attn_metadata: CommonAttentionMetadata,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Build a new CommonAttentionMetadata in which every query is 1 token longer.

    Sequence lengths grow by 1 as well. This is useful e.g. in speculative
    decoding with draft models, where each sequence is extended by 1 token.

    Args:
        common_attn_metadata: Metadata to derive the extended copy from. It is
            not modified; the result comes from its ``replace()`` method.
        arange: Tensor whose first ``len(query_start_loc)`` entries are added
            to ``query_start_loc`` (expected to hold 0, 1, 2, ...).
        new_slot_mapping: Slot mapping for the extended batch. It is computed
            externally because it requires more information than is
            available here.

    Returns:
        The extended metadata.
    """
    old_meta = common_attn_metadata

    # Boundary i shifts right by i, i.e. by [+0, +1, +2, ..., +batch_size].
    device_offsets = arange[: len(old_meta.query_start_loc)]
    host_offsets = torch.arange(
        len(old_meta.query_start_loc_cpu), dtype=torch.int32
    )

    overrides = {
        "query_start_loc": old_meta.query_start_loc + device_offsets,
        "query_start_loc_cpu": old_meta.query_start_loc_cpu + host_offsets,
        "seq_lens": old_meta.seq_lens + 1,
        # One extra token per request -> batch_size extra tokens in total.
        "num_actual_tokens": old_meta.num_actual_tokens + old_meta.batch_size(),
        # Every query grows by 1, so the maxima grow by 1 too.
        "max_query_len": old_meta.max_query_len + 1,
        "max_seq_len": old_meta.max_seq_len + 1,
        "slot_mapping": new_slot_mapping,
    }
    return old_meta.replace(**overrides)
vllm/v1/core/sched/scheduler.py
View file @
4a5299c9
...
@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
...
@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
if
speculative_config
.
use_eagle
():
if
speculative_config
.
use_eagle
():
self
.
use_eagle
=
True
self
.
use_eagle
=
True
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
if
speculative_config
.
uses_draft_model
():
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
# Create the KV cache manager.
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
self
.
kv_cache_manager
=
KVCacheManager
(
...
...
vllm/v1/spec_decode/draft_model.py
0 → 100644
View file @
4a5299c9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
extend_all_queries_by_1
,
)
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
SpecDecodeBaseProposer
logger
=
init_logger
(
__name__
)
class DraftModelProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer that drafts tokens with a draft model.

    The base proposer is configured with ``pass_hidden_states_to_model=False``.
    Unsupported configurations are rejected in the constructor.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Initialize the base proposer, then validate the configuration.

        Raises:
            NotImplementedError: for multimodal models, M-RoPE draft models,
                or when the padded drafter batch is disabled.
            ValueError: when draft and target tensor-parallel sizes differ.
        """
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        # Run the checks in a fixed order; the first failing one raises.
        validators = (
            self._raise_if_multimodal,
            self._raise_if_mrope,
            self._raise_if_padded_drafter_batch_disabled,
            self._raise_if_vocab_size_mismatch,
            self._raise_if_draft_tp_mismatch,
        )
        for validate in validators:
            validate()

    def _block_size(self) -> int:
        """Return ``kv_cache_spec.block_size`` of the attention metadata builder."""
        metadata_builder = self._get_attention_metadata_builder()
        kv_cache_spec = metadata_builder.kv_cache_spec
        return kv_cache_spec.block_size

    def _raise_if_multimodal(self):
        """Raise NotImplementedError when ``supports_mm_inputs`` is set."""
        if not self.supports_mm_inputs:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models "
            "does not support multimodal models yet"
        )

    def _raise_if_mrope(self):
        """Raise NotImplementedError when the draft model config uses M-RoPE."""
        if not self.draft_model_config.uses_mrope:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models does not support M-RoPE yet"
        )

    def _raise_if_padded_drafter_batch_disabled(self):
        """Raise NotImplementedError when the padded drafter batch is disabled."""
        spec_cfg = self.vllm_config.speculative_config
        if not spec_cfg.disable_padded_drafter_batch:
            return
        raise NotImplementedError(
            "Speculative Decoding with draft models only supports "
            "padded drafter batch. Please don't pass --disable-padded-drafter-batch"
            " in the speculative_config."
        )

    def _raise_if_vocab_size_mismatch(self):
        """Delegate the vocab-size check to the speculative config."""
        spec_cfg = self.vllm_config.speculative_config
        spec_cfg.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self):
        """Raise ValueError unless draft and target use the same TP size."""
        # Note(Tomas Ruiz) If we run the target model with TP > 1 and
        # the draft model with TP = 1, then the different TP ranks collide.
        # Specifically when all ranks compile the draft model on rank 0
        # (because TP=1), then the torch compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp == tgt_tp:
            return
        raise ValueError(
            f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
            f"must be the same. Got {draft_tp} and {tgt_tp}. "
            "Please pass 'draft_tensor_parallel_size' in the speculative_config."
        )

    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the input buffers for the first drafting pass.

        Each request's query is extended by one slot: the target tokens are
        merged with ``next_token_ids`` into ``self.input_ids``, and the
        positions are merged the same way into ``self.positions``. The slot
        mapping and attention metadata are then rebuilt for the longer
        queries.

        Args:
            target_token_ids: Token ids of the target pass.
            next_token_ids: One token per request, placed right after the
                request's last non-rejected token.
            target_positions: Positions matching ``target_token_ids``.
            last_token_indices: Accepted for interface compatibility; this
                implementation does not read it.
            cad: Attention metadata of the target pass.
            num_rejected_tokens_gpu: Optional per-request count of rejected
                tokens, subtracted from the per-request end locations.

        Returns:
            ``(num_tokens, new_last_token_indices, new_cad)`` for the
            extended batch.
        """
        num_reqs = cad.batch_size()
        launch_grid = (num_reqs,)

        query_starts = cad.query_start_loc[:-1]
        # Location of the last non-rejected token of every request.
        query_ends = cad.query_start_loc[1:] - 1
        if num_rejected_tokens_gpu is not None:
            query_ends -= num_rejected_tokens_gpu

        # One extra slot per request.
        num_tokens = target_token_ids.shape[0] + num_reqs
        is_rejected_tok = torch.empty(
            (num_tokens,), device=self.input_ids.device, dtype=torch.bool
        )

        # Merge the token ids.
        merge_toks_kernel[launch_grid](
            target_toks_ptr=target_token_ids,
            next_toks_ptr=next_token_ids,
            query_start_locs_ptr=query_starts,
            query_end_locs_ptr=query_ends,
            out_ptr_merged_toks=self.input_ids,
            out_ptr_is_rejected_tok=is_rejected_tok,
            target_toks_size=target_token_ids.shape[0],
            # passing a negative rejected_tok_fill value will raise an error
            # when the value is used to index into embeddings.
            # Therefore, we pass a valid integer, e.g. 0.
            rejected_tok_fill=0,
        )
        # Merge the positions: the extra slot of each request gets the
        # position following its last non-rejected token.
        merge_toks_kernel[launch_grid](
            target_toks_ptr=target_positions,
            next_toks_ptr=target_positions[query_ends] + 1,
            query_start_locs_ptr=query_starts,
            query_end_locs_ptr=query_ends,
            out_ptr_merged_toks=self.positions,
            out_ptr_is_rejected_tok=is_rejected_tok,
            target_toks_size=target_positions.shape[0],
            rejected_tok_fill=0,
        )

        # Slot mapping for the extended batch.
        new_slot_mapping = compute_new_slot_mapping(
            cad=cad,
            new_positions=self.positions[:num_tokens],
            is_rejected_token_mask=is_rejected_tok,
            block_size=self._block_size(),
            max_model_len=self.max_model_len,
        )

        # Attention metadata for the extended batch.
        new_cad: CommonAttentionMetadata = extend_all_queries_by_1(
            cad,
            arange=self.arange,
            new_slot_mapping=new_slot_mapping,
        )

        new_last_token_indices = new_cad.query_start_loc[1:] - 1
        if num_rejected_tokens_gpu is not None:
            new_last_token_indices -= num_rejected_tokens_gpu

        return num_tokens, new_last_token_indices, new_cad

    def load_model(self, target_model: Any) -> None:
        """Load the draft model and record its attention layer names.

        Takes target_model to satisfy the type checker.
        """
        # Snapshot the attention layers BEFORE loading the draft model,
        # because loading mutates the forward_context of the vllm_config.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        )

        from vllm.compilation.backends import set_model_tag

        draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
            target_model_vllm_config=self.vllm_config
        )
        draft_parallel_config = draft_vllm_config.parallel_config
        logger.info(
            "Starting to load draft model %s. TP=%d, rank=%d",
            draft_vllm_config.model_config.model,
            draft_parallel_config.tensor_parallel_size,
            draft_parallel_config.rank,
        )
        with set_model_tag("draft_model"):
            self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")

        # Layers present only AFTER loading belong to the draft model
        # (loading mutated the forward_context of the vllm_config).
        layers_after_load = get_layers_from_vllm_config(self.vllm_config, Attention)
        draft_attn_layer_names = layers_after_load.keys() - target_attn_layer_names
        self.attn_layer_names = list(draft_attn_layer_names)
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Derive a vllm_config for the draft model from the target model's one.

    The incoming vllm_config is set up for the target model, e.g. its
    quant_config and parallel_config. The draft model may be quantized
    differently and may use a different tensor_parallel_size, so a new
    vllm_config is created for it. The result is meant to be passed to
    get_model() when loading the draft model.

    Args:
        target_model_vllm_config: Config of the target model; its
            ``speculative_config`` supplies the draft model and draft
            parallel configs.

    Returns:
        A new VllmConfig using the draft model config and the draft parallel
        config, with the rank copied from the target's parallel config.
    """
    spec_cfg = target_model_vllm_config.speculative_config
    # Keep the draft parallel settings but run on the target's rank.
    draft_parallel_config = spec_cfg.draft_parallel_config.replace(
        rank=target_model_vllm_config.parallel_config.rank
    )
    draft_vllm_config: VllmConfig = target_model_vllm_config.replace(
        quant_config=None,  # quant_config is recomputed in __init__()
        model_config=spec_cfg.draft_model_config,
        parallel_config=draft_parallel_config,
    )
    return draft_vllm_config
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    max_model_len: int,
):
    """Compute the slot mapping for a batch whose queries grew by one token.

    Args:
        cad: Metadata of the batch before extension; its block table, query
            start locations and query lengths are read.
        new_positions: Position of every token of the extended batch.
        is_rejected_token_mask: Mask over the extended batch marking rejected
            tokens.
        block_size: Number of slots per block.
        max_model_len: Positions at or beyond this value get no slot.

    Returns:
        One slot per entry of ``new_positions``; entries that are rejected or
        exceed ``max_model_len`` are set to ``PADDING_SLOT_ID``.
    """
    block_table = cad.block_table_tensor
    batch_size, n_blocks_per_req = block_table.shape

    # Request index of every token: each request contributes its query
    # length plus one entry.
    tokens_per_req = cad.naive_query_lens() + 1
    req_of_token = torch.repeat_interleave(
        torch.arange(batch_size, device=cad.query_start_loc.device),
        tokens_per_req,
        output_size=len(new_positions),
    )

    # Clamp the positions to prevent an out-of-bounds error when indexing
    # into block_table_tensor.
    safe_positions = torch.clamp(new_positions, max=max_model_len - 1)

    # Index into the flattened block table, then turn (block, offset) into a slot.
    flat_block_idx = req_of_token * n_blocks_per_req + safe_positions // block_size
    block_nums = block_table.view(-1)[flat_block_idx]
    new_slot_mapping = block_nums * block_size + safe_positions % block_size

    # Mask out the position ids that exceed the max model length.
    new_slot_mapping.masked_fill_(new_positions >= max_model_len, PADDING_SLOT_ID)
    # Mask out rejected tokens to prevent saves to the KV cache.
    new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
    return new_slot_mapping
@triton.jit
def merge_toks_kernel(
    target_toks_ptr,
    next_toks_ptr,
    query_start_locs_ptr,
    query_end_locs_ptr,
    out_ptr_merged_toks,
    out_ptr_is_rejected_tok,
    target_toks_size,
    rejected_tok_fill,
):
    """
    Merge `target_toks_ptr` and `next_toks_ptr` into `out_ptr_merged_toks`.

    One program handles one request. Its tokens up to and including
    `query_end_locs_ptr` are copied, the request's entry of `next_toks_ptr`
    is written right after them, and the remaining slots before the next
    request's start (the rejected tokens) are filled with `rejected_tok_fill`.
    `out_ptr_is_rejected_tok` receives True for the rejected slots and False
    for all others. Program `p` writes at input index + `p`, since every
    earlier request gained one slot.
    """
    req = tl.program_id(0)
    first_loc = tl.load(query_start_locs_ptr + req)

    # The last request has no successor; its range ends at the input size.
    handles_last_req = req == tl.num_programs(0) - 1
    if handles_last_req:
        range_end = target_toks_size.to(tl.int32)
    else:
        range_end = tl.load(query_start_locs_ptr + req + 1).to(tl.int32)

    last_kept_loc = tl.load(query_end_locs_ptr + req)
    appended_val = tl.load(next_toks_ptr + req)

    # range_end is included: each request emits one slot more than it reads.
    for src in range(first_loc, range_end + 1):
        dst = req + src
        if src <= last_kept_loc:
            # copy existing tokens
            kept_val = tl.load(target_toks_ptr + src)
            tl.store(out_ptr_merged_toks + dst, kept_val)
            tl.store(out_ptr_is_rejected_tok + dst, False)
        elif src == last_kept_loc + 1:
            # copy bonus token
            tl.store(out_ptr_merged_toks + dst, appended_val)
            tl.store(out_ptr_is_rejected_tok + dst, False)
        else:
            # fill rejected tokens
            tl.store(out_ptr_merged_toks + dst, rejected_tok_fill)
            tl.store(out_ptr_is_rejected_tok + dst, True)
vllm/v1/spec_decode/eagle.py
View file @
4a5299c9
...
@@ -53,11 +53,12 @@ logger = init_logger(__name__)
...
@@ -53,11 +53,12 @@ logger = init_logger(__name__)
PADDING_SLOT_ID
=
-
1
PADDING_SLOT_ID
=
-
1
class
Eagl
eProposer
:
class
SpecDecodeBas
eProposer
:
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
device
:
torch
.
device
,
pass_hidden_states_to_model
:
bool
,
runner
=
None
,
runner
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
...
@@ -65,6 +66,7 @@ class EagleProposer:
...
@@ -65,6 +66,7 @@ class EagleProposer:
assert
self
.
speculative_config
is
not
None
assert
self
.
speculative_config
is
not
None
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
method
=
self
.
speculative_config
.
method
self
.
method
=
self
.
speculative_config
.
method
self
.
pass_hidden_states_to_model
=
pass_hidden_states_to_model
self
.
runner
=
runner
self
.
runner
=
runner
self
.
device
=
device
self
.
device
=
device
...
@@ -72,7 +74,11 @@ class EagleProposer:
...
@@ -72,7 +74,11 @@ class EagleProposer:
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# We need to get the hidden size from the draft model config because
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# the draft model's hidden size can be different from the target model's
...
@@ -143,7 +149,6 @@ class EagleProposer:
...
@@ -143,7 +149,6 @@ class EagleProposer:
# We need +1 here because the arange is used to set query_start_loc,
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
# which has one more element than batch_size.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
self
.
arange
=
torch
.
arange
(
self
.
arange
=
torch
.
arange
(
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
...
@@ -245,11 +250,7 @@ class EagleProposer:
...
@@ -245,11 +250,7 @@ class EagleProposer:
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
common_attn_metadata
.
batch_size
()
batch_size
=
next_token_ids
.
shape
[
0
]
if
last_token_indices
is
None
:
last_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
...
@@ -257,12 +258,17 @@ class EagleProposer:
...
@@ -257,12 +258,17 @@ class EagleProposer:
target_hidden_states
target_hidden_states
)
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
num_tokens
,
last_token_indices
,
common_attn_metadata
=
(
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
self
.
set_inputs_first_pass
(
# Replace the last token with the next token.
target_token_ids
=
target_token_ids
,
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
next_token_ids
=
next_token_ids
,
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
target_positions
=
target_positions
,
last_token_indices
=
last_token_indices
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
)
assert
self
.
runner
is
not
None
assert
self
.
runner
is
not
None
...
@@ -311,9 +317,10 @@ class EagleProposer:
...
@@ -311,9 +317,10 @@ class EagleProposer:
if
num_tokens_across_dp
is
not
None
:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
# copy inputs to buffer for cudagraph
if
self
.
pass_hidden_states_to_model
:
self
.
_set_positions
(
num_tokens
,
target_positions
)
# target_hidden_states and self.hidden_states can have different
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
# hidden dims. E.g. large target model and small draft model.
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
...
@@ -330,6 +337,14 @@ class EagleProposer:
...
@@ -330,6 +337,14 @@ class EagleProposer:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
with
set_forward_context
(
with
set_forward_context
(
per_layer_attn_metadata
,
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -337,17 +352,13 @@ class EagleProposer:
...
@@ -337,17 +352,13 @@ class EagleProposer:
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
):
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
input_ids
=
input_ids
,
if
not
self
.
model_returns_tuple
():
positions
=
self
.
_get_positions
(
num_input_tokens
),
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
hidden_states
=
last_hidden_states
else
:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
...
@@ -357,9 +368,9 @@ class EagleProposer:
...
@@ -357,9 +368,9 @@ class EagleProposer:
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
positions
=
target_
positions
[:,
last_token_indices
]
positions
=
self
.
positions
[:,
last_token_indices
]
else
:
else
:
positions
=
target_
positions
[
last_token_indices
]
positions
=
self
.
positions
[
last_token_indices
]
if
self
.
method
in
(
if
self
.
method
in
(
"deepseek_mtp"
,
"deepseek_mtp"
,
"ernie_mtp"
,
"ernie_mtp"
,
...
@@ -527,6 +538,14 @@ class EagleProposer:
...
@@ -527,6 +538,14 @@ class EagleProposer:
inputs_embeds
=
None
inputs_embeds
=
None
# Run the model.
# Run the model.
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
input_batch_size
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
input_batch_size
]
with
set_forward_context
(
with
set_forward_context
(
per_layer_attn_metadata
,
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -534,17 +553,13 @@ class EagleProposer:
...
@@ -534,17 +553,13 @@ class EagleProposer:
num_tokens_across_dp
=
batch_size_across_dp
,
num_tokens_across_dp
=
batch_size_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
):
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
input_ids
=
input_ids
,
if
not
self
.
model_returns_tuple
():
positions
=
self
.
_get_positions
(
input_batch_size
),
hidden_states
=
self
.
hidden_states
[:
input_batch_size
],
inputs_embeds
=
inputs_embeds
,
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
])
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
])
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
...
@@ -554,6 +569,34 @@ class EagleProposer:
...
@@ -554,6 +569,34 @@ class EagleProposer:
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
return
draft_token_ids
def set_inputs_first_pass(
    self,
    target_token_ids: torch.Tensor,
    next_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    last_token_indices: torch.Tensor | None,
    cad: CommonAttentionMetadata,
    num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
    """Fill the input-id and position buffers for the first drafting pass.

    Args:
        target_token_ids: Token ids of the target pass.
        next_token_ids: One token per request, written at that request's
            last-token index.
        target_positions: Positions forwarded to ``_set_positions``.
        last_token_indices: Index of each request's last token; when None it
            is derived as ``cad.query_start_loc[1:] - 1``.
        cad: Attention metadata; returned unchanged.
        num_rejected_tokens_gpu: Not read by this implementation.

    Returns:
        ``(num_tokens, last_token_indices, cad)``.
    """
    last_idx = last_token_indices
    if last_idx is None:
        last_idx = cad.query_start_loc[1:] - 1

    token_count = target_token_ids.shape[0]
    # Shift left by one token:
    # [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
    shifted_ids = target_token_ids[1:]
    self.input_ids[: token_count - 1] = shifted_ids
    # Put each request's next token at its last-token index:
    # [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
    self.input_ids[last_idx] = next_token_ids

    # copy inputs to buffer for cudagraph
    self._set_positions(token_count, target_positions)

    return token_count, last_idx, cad
def model_returns_tuple(self) -> bool:
    """Return False for the "mtp" and "draft_model" methods, True otherwise.

    The return value of the model call is unpacked as a 2-tuple only when
    this returns True.
    """
    single_output_methods = ("mtp", "draft_model")
    is_single_output = self.method in single_output_methods
    return not is_single_output
def
prepare_next_token_ids_cpu
(
def
prepare_next_token_ids_cpu
(
self
,
self
,
sampled_token_ids
:
list
[
list
[
int
]],
sampled_token_ids
:
list
[
list
[
int
]],
...
@@ -1214,12 +1257,14 @@ class EagleProposer:
...
@@ -1214,12 +1257,14 @@ class EagleProposer:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
self
.
model
(
kwargs
=
dict
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
self
.
_get_positions
(
num_input_tokens
),
positions
=
self
.
_get_positions
(
num_input_tokens
),
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
if
self
.
pass_hidden_states_to_model
:
kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
self
.
model
(
**
kwargs
)
def
_get_attention_metadata_builder
(
self
)
->
AttentionMetadataBuilder
:
def
_get_attention_metadata_builder
(
self
)
->
AttentionMetadataBuilder
:
"""Find and return the attention metadata builders for EAGLE layers.
"""Find and return the attention metadata builders for EAGLE layers.
...
@@ -1264,8 +1309,8 @@ class EagleProposer:
...
@@ -1264,8 +1309,8 @@ class EagleProposer:
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
Validate that all
eagle
layers belong to the same KVCacheGroup.
Validate that all
drafting
layers belong to the same KVCacheGroup.
Need this assumption to ensure all
eagle
layers can use the
Need this assumption to ensure all
drafting
layers can use the
same AttentionMetadata.
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
May extend to multiple AttentionMetadata in the future.
"""
"""
...
@@ -1283,7 +1328,7 @@ class EagleProposer:
...
@@ -1283,7 +1328,7 @@ class EagleProposer:
)
)
)
)
==
1
==
1
),
"All
eagle
layers should belong to the same kv cache group"
),
"All
drafting
layers should belong to the same kv cache group"
def
_pad_batch_across_dp
(
def
_pad_batch_across_dp
(
self
,
self
,
...
@@ -1308,6 +1353,21 @@ class EagleProposer:
...
@@ -1308,6 +1353,21 @@ class EagleProposer:
return
num_tokens_dp_padded
,
num_toks_across_dp
return
num_tokens_dp_padded
,
num_toks_across_dp
class EagleProposer(SpecDecodeBaseProposer):
    """Proposer built on ``SpecDecodeBaseProposer`` with hidden-state passing on.

    The only customization visible here is that the base class is always
    initialized with ``pass_hidden_states_to_model=True``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        # Forward everything to the base proposer unchanged, fixing only the
        # hidden-state flag.
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# the draft prob tensor.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4a5299c9
...
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
...
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
@@ -432,10 +433,20 @@ class GPUModelRunner(
...
@@ -432,10 +433,20 @@ class GPUModelRunner(
# layers in the draft model.
# layers in the draft model.
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
if
self
.
speculative_config
and
get_pp_group
().
is_last_rank
:
self
.
drafter
:
(
self
.
drafter
:
(
NgramProposer
|
SuffixDecodingProposer
|
EagleProposer
|
MedusaProposer
NgramProposer
|
SuffixDecodingProposer
|
EagleProposer
|
DraftModelProposer
|
MedusaProposer
)
)
if
self
.
speculative_config
.
method
==
"ngram"
:
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
uses_draft_model
():
self
.
drafter
=
DraftModelProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
,
runner
=
self
,
)
elif
self
.
speculative_config
.
method
==
"suffix"
:
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
...
@@ -3443,10 +3454,13 @@ class GPUModelRunner(
...
@@ -3443,10 +3454,13 @@ class GPUModelRunner(
spec_decode_common_attn_metadata
.
max_seq_len
+
self
.
num_spec_tokens
spec_decode_common_attn_metadata
.
max_seq_len
+
self
.
num_spec_tokens
<=
self
.
effective_drafter_max_model_len
<=
self
.
effective_drafter_max_model_len
)
)
if
spec_config
.
use_eagle
()
and
not
spec_config
.
disable_padded_drafter_batch
:
use_gpu_toks
=
(
# EAGLE speculative decoding can use the GPU sampled tokens
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
)
and
not
spec_config
.
disable_padded_drafter_batch
if
use_gpu_toks
:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
# as inputs, and does not need to wait for bookkeeping to finish.
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
propose_draft_token_ids
(
sampled_token_ids
)
...
@@ -3679,8 +3693,8 @@ class GPUModelRunner(
...
@@ -3679,8 +3693,8 @@ class GPUModelRunner(
target_hidden_states
=
hidden_states
,
target_hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
elif
spec_config
.
use_eagle
():
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
:
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
if
spec_config
.
disable_padded_drafter_batch
:
if
spec_config
.
disable_padded_drafter_batch
:
# When padded-batch is disabled, the sampled_token_ids should be
# When padded-batch is disabled, the sampled_token_ids should be
...
@@ -4475,8 +4489,12 @@ class GPUModelRunner(
...
@@ -4475,8 +4489,12 @@ class GPUModelRunner(
else
:
else
:
hidden_states
=
outputs
hidden_states
=
outputs
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
(
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
self
.
speculative_config
is
not
None
# Eagle currently only supports PIECEWISE cudagraphs.
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
# NOTE(lucas): this is a hack, need to clean up.
...
@@ -5652,8 +5670,11 @@ class GPUModelRunner(
...
@@ -5652,8 +5670,11 @@ class GPUModelRunner(
kv_cache_config
,
kernel_block_sizes
kv_cache_config
,
kernel_block_sizes
)
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
(
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
# validate all draft model layers belong to the same kv cache
# validate all draft model layers belong to the same kv cache
# group
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment