Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3fce56b
Unverified
Commit
a3fce56b
authored
Aug 22, 2024
by
Abhinav Goyal
Committed by
GitHub
Aug 22, 2024
Browse files
[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
parent
b3856bef
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
854 additions
and
83 deletions
+854
-83
tests/spec_decode/e2e/test_eagle_correctness.py
tests/spec_decode/e2e/test_eagle_correctness.py
+268
-0
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+56
-12
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+35
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+161
-0
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+19
-0
vllm/sequence.py
vllm/sequence.py
+60
-9
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+19
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+8
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+72
-25
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+6
-4
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+49
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+34
-10
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+13
-8
vllm/worker/worker.py
vllm/worker/worker.py
+1
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+50
-11
No files found.
tests/spec_decode/e2e/test_eagle_correctness.py
0 → 100644
View file @
a3fce56b
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, EAGLE would not break the
correctess for the target model outputs.
"""
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-eagle-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
4
# precision
PRECISION
=
"float32"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"enforce_eager"
:
False
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness_cuda_graph
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
a3fce56b
...
@@ -70,8 +70,9 @@ PRECISION = "float32"
...
@@ -70,8 +70,9 @@ PRECISION = "float32"
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
def
test_medusa_e2e_greedy_correctness
(
baseline_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with different batch size."""
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
test_llm_generator
,
...
@@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
...
@@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"enforce_eager"
:
False
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_e2e_greedy_correctness_cuda_graph
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
@@ -116,7 +160,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
...
@@ -116,7 +160,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_m
lp
_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
def
test_m
edusa
_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
batch_size
:
int
,
output_len
:
int
):
output_len
:
int
):
...
@@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
...
@@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32
,
32
,
])
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_m
lp
_different_k
(
baseline_llm_generator
,
test_llm_generator
,
def
test_m
edusa
_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
batch_size
:
int
,
output_len
:
int
):
"""Verify that m
lp
speculative decoding produces exact equality
"""Verify that m
edusa
speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
to without spec decode with different values of num_speculative_tokens.
"""
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
...
@@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
...
@@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
32
,
32
,
])
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_m
lp
_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
def
test_m
edusa
_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
batch_size
:
int
,
output_len
:
int
):
"""Verify that m
lp
speculative decoding produces exact equality
"""Verify that m
edusa
speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
to without spec decode when speculation is disabled for large
batch sizes.
batch sizes.
"""
"""
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
a3fce56b
...
@@ -6,7 +6,8 @@ import pytest
...
@@ -6,7 +6,8 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
Logprob
,
SamplerOutput
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
Logprob
,
SamplerOutput
,
get_all_seq_ids
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
@@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
...
@@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
@
torch
.
inference_mode
()
def
test_expand_execute_model_request_sync_with_expand_hidden_states
():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k
=
5
batch_size
=
16
seq_with_bonus_token_in_last_step
=
[
1
,
3
,
8
,
10
,
13
,
15
]
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_request
=
ExecuteModelRequest
(
seq_group_metadata_list
,
previous_hidden_states
=
HiddenStates
(
torch
.
arange
(
batch_size
),
seq_group_metadata_list
,
torch
.
arange
(
batch_size
,
2
*
batch_size
)))
expanded_execute_model_request
,
orig_seq_group_ids
=
MultiStepWorker
.
\
_expand_execute_model_request
(
execute_model_request
,
seq_with_bonus_token_in_last_step
)
all_seq_ids
=
torch
.
tensor
(
get_all_seq_ids
(
expanded_execute_model_request
.
seq_group_metadata_list
))
ref_expanded_hidden_states
=
all_seq_ids
+
batch_size
ref_expanded_hidden_states
[
orig_seq_group_ids
]
-=
batch_size
assert
(
ref_expanded_hidden_states
==
expanded_execute_model_request
.
previous_hidden_states
.
hidden_states
).
all
().
item
()
vllm/model_executor/models/__init__.py
View file @
a3fce56b
...
@@ -60,6 +60,7 @@ _GENERATION_MODELS = {
...
@@ -60,6 +60,7 @@ _GENERATION_MODELS = {
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"Phi3SmallForCausalLM"
:
(
"phi3_small"
,
"Phi3SmallForCausalLM"
),
"Phi3SmallForCausalLM"
:
(
"phi3_small"
,
"Phi3SmallForCausalLM"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"EAGLEModel"
:
(
"eagle"
,
"EAGLE"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
}
}
...
...
vllm/model_executor/models/eagle.py
0 → 100644
View file @
a3fce56b
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
class
EAGLE
(
nn
.
Module
):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def
__init__
(
self
,
config
:
EAGLEConfig
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
config
=
config
architectures
=
getattr
(
self
.
config
.
model
,
"architectures"
,
[])
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
self
.
model
=
model_cls
(
self
.
config
.
model
,
*
args
,
**
kwargs
)
self
.
fc
=
nn
.
Linear
(
config
.
model
.
hidden_size
*
2
,
config
.
model
.
hidden_size
,
bias
=
False
)
self
.
orig_vocab_size
=
config
.
vocab_size
self
.
truncated_vocab_size
=
config
.
truncated_vocab_size
self
.
unpadded_vocab_size
=
self
.
truncated_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
self
.
truncated_vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
truncated_vocab_size
,
logit_scale
)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self
.
token_map
=
None
@
property
def
sampler
(
self
):
return
self
.
model
.
sampler
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
previous_hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
tok_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
fc
(
torch
.
cat
([
tok_embeds
,
previous_hidden_states
],
dim
=-
1
))
inputs_embeds
[
positions
==
0
]
=
0
# masking inputs at position=0
hidden_states
=
self
.
model
.
model
(
input_ids
=
None
,
inputs_embeds
=
inputs_embeds
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
if
self
.
token_map
is
not
None
:
_logits
=
logits
logits
=
-
torch
.
inf
*
torch
.
ones
(
size
=
(
*
_logits
.
shape
[:
-
1
],
self
.
orig_vocab_size
),
device
=
_logits
.
device
,
dtype
=
_logits
.
dtype
)
logits
[...,
self
.
token_map
]
=
_logits
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights
=
{}
for
name
,
loaded_weight
in
weights
:
if
name
==
"token_map"
:
if
self
.
config
.
truncated_vocab_size
<
self
.
config
.
vocab_size
:
self
.
token_map
=
nn
.
Parameter
(
loaded_weight
,
requires_grad
=
False
)
elif
name
.
startswith
(
"fc."
):
weight_loader
=
getattr
(
self
.
fc
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
fc
.
weight
,
loaded_weight
)
elif
name
.
startswith
(
"model.lm_head."
)
or
name
.
startswith
(
"model.model."
):
model_weights
[
name
.
split
(
"model."
,
1
)[
-
1
]]
=
loaded_weight
elif
name
.
startswith
(
"lm_head."
)
or
name
.
startswith
(
"model."
):
model_weights
[
name
]
=
loaded_weight
else
:
model_weights
[
f
"model.
{
name
}
"
]
=
loaded_weight
lm_head_weight
=
model_weights
.
pop
(
"lm_head.weight"
)
if
self
.
token_map
is
not
None
and
\
lm_head_weight
.
shape
[
0
]
>
self
.
token_map
.
shape
[
0
]:
lm_head_weight
=
lm_head_weight
[
self
.
token_map
]
weight_loader
=
getattr
(
self
.
lm_head
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
lm_head
.
weight
,
lm_head_weight
)
self
.
model
.
load_weights
(
model_weights
.
items
())
vllm/model_executor/models/medusa.py
View file @
a3fce56b
...
@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
...
@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
class
Medusa
(
nn
.
Module
):
class
Medusa
(
nn
.
Module
):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
...
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
self
.
truncated_vocab_size
,
self
.
truncated_vocab_size
,
logit_scale
)
logit_scale
)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self
.
token_map
=
None
self
.
token_map
=
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
List
[
torch
.
Tensor
]:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
List
[
torch
.
Tensor
]:
...
...
vllm/sequence.py
View file @
a3fce56b
...
@@ -1092,6 +1092,10 @@ class SamplerOutput(
...
@@ -1092,6 +1092,10 @@ class SamplerOutput(
# Optional last hidden states from the model.
# Optional last hidden states from the model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Time taken in the forward pass for this across all workers
# Time taken in the forward pass for this across all workers
model_forward_time
:
Optional
[
float
]
=
None
model_forward_time
:
Optional
[
float
]
=
None
...
@@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults
=
True
):
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
Used in speculative decoding to pass hidden states from
the target model to the proposer model
in the subsequent step
.
the target model to the proposer model.
seq_ids are the sequence ids of each entry of the batch
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""
dimension of the hidden_states tensor"""
# Scorer hidden states. For prefill step, it is used for hidden states of
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
# all tokens, whereas for decode step, it use used for last accepted tokens.
hidden_states
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
_seq_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_seq_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
if
self
.
seq_group_metadata_list
is
not
None
:
assert
len
(
self
.
seq_group_metadata_list
)
==
len
(
self
.
hidden_states
)
assert
len
(
self
.
seq_group_metadata_list
)
==
len
(
self
.
hidden_states
)
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
@
property
@
property
def
seq_ids
(
self
)
->
List
[
int
]:
def
seq_ids
(
self
)
->
List
[
int
]:
return
self
.
_seq_ids
return
self
.
_seq_ids
def
update
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
def
update
(
self
,
hidden_states
:
torch
.
Tensor
)
->
None
:
hidden_states
:
torch
.
Tensor
,
"""Update hidden states from target model invocation."""
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert
len
(
seq_group_metadata_list
)
==
len
(
hidden_states
)
assert
len
(
seq_group_metadata_list
)
==
len
(
hidden_states
)
self
.
_seq_ids
.
extend
(
get_all_seq_ids
(
seq_group_metadata_list
))
self
.
_seq_ids
.
extend
(
get_all_seq_ids
(
seq_group_metadata_list
))
self
.
hidden_states
=
torch
.
cat
([
self
.
hidden_states
,
hidden_states
])
self
.
hidden_states
=
torch
.
cat
([
self
.
hidden_states
,
hidden_states
])
if
self
.
second_last_token_hidden_states
is
not
None
:
# Adding dummy hidden_states to this to maintain same shape
self
.
second_last_token_hidden_states
=
torch
.
cat
([
self
.
second_last_token_hidden_states
,
torch
.
zeros_like
(
hidden_states
)
if
second_last_token_hidden_states
is
None
else
second_last_token_hidden_states
])
def
prune
(
self
,
def
prune
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
"""Prune to provided list of sequence ids."""
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
# seq_group_metadata_list which might cause problems where a sequence
# may be "paused" then "resumed" later. This should only prune sequences
# which are confirmed to be aborted.
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
if
seq_ids
!=
self
.
_seq_ids
:
if
seq_ids
!=
self
.
_seq_ids
:
# Batch contents changed - prune removed sequences.
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
index
=
[
self
.
_seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
self
.
hidden_states
=
self
.
hidden_states
[
index
]
self
.
hidden_states
=
self
.
hidden_states
[
index
]
if
self
.
second_last_token_hidden_states
is
not
None
:
self
.
second_last_token_hidden_states
=
self
\
.
second_last_token_hidden_states
[
index
]
self
.
_seq_ids
=
seq_ids
self
.
_seq_ids
=
seq_ids
def
expand_with_bonus_tokens
(
self
,
seq_with_bonus_token_in_last_step
:
set
)
->
None
:
"""Expand hidden states for sequences with bonus tokens. This is in
alignment with `MultiStepWorker._expand_execute_model_request`."""
if
self
.
second_last_token_hidden_states
is
None
\
or
not
seq_with_bonus_token_in_last_step
:
return
index
=
[]
for
seq_id
in
self
.
_seq_ids
:
i
=
self
.
_seq_ids
.
index
(
seq_id
)
if
seq_id
in
seq_with_bonus_token_in_last_step
:
index
.
append
(
i
+
len
(
self
.
_seq_ids
))
index
.
append
(
i
)
self
.
hidden_states
=
torch
.
cat
(
[
self
.
hidden_states
,
self
.
second_last_token_hidden_states
])[
index
]
class
ExecuteModelRequest
(
class
ExecuteModelRequest
(
msgspec
.
Struct
,
msgspec
.
Struct
,
...
...
vllm/spec_decode/draft_model_runner.py
View file @
a3fce56b
...
@@ -203,6 +203,7 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -203,6 +203,7 @@ class TP1DraftModelRunner(ModelRunner):
self
,
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
...
@@ -280,13 +281,30 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -280,13 +281,30 @@ class TP1DraftModelRunner(ModelRunner):
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
(
self
.
graph_runners
[
model_input
.
virtual_engine
]
model_executable
=
(
self
.
graph_runners
[
model_input
.
virtual_engine
]
[
graph_batch_size
])
[
graph_batch_size
])
if
previous_hidden_states
is
not
None
:
hidden_states
=
torch
.
cat
([
previous_hidden_states
,
torch
.
empty
([
graph_batch_size
-
previous_hidden_states
.
shape
[
0
],
*
previous_hidden_states
.
shape
[
1
:]
],
dtype
=
previous_hidden_states
.
dtype
,
device
=
previous_hidden_states
.
device
)
])
else
:
hidden_states
=
None
else
:
else
:
model_executable
=
self
.
model
model_executable
=
self
.
model
hidden_states
=
previous_hidden_states
outputs
:
List
[
SamplerOutput
]
=
[]
outputs
:
List
[
SamplerOutput
]
=
[]
for
step
in
range
(
num_steps
):
for
step
in
range
(
num_steps
):
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
kwargs
=
{
"previous_hidden_states"
:
hidden_states
}
\
if
previous_hidden_states
is
not
None
else
{}
# Run model
# Run model
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
...
@@ -296,6 +314,7 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -296,6 +314,7 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalInputs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalInputs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
device
=
self
.
device
),
**
kwargs
,
)
)
# Compute the logits.
# Compute the logits.
...
...
vllm/spec_decode/multi_step_worker.py
View file @
a3fce56b
...
@@ -4,8 +4,8 @@ from typing import Dict, List, Set, Tuple
...
@@ -4,8 +4,8 @@ from typing import Dict, List, Set, Tuple
import
torch
import
torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
...
@@ -157,6 +157,12 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -157,6 +157,12 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
updated_execute_model_req
.
seq_group_metadata_list
=
\
updated_execute_model_req
.
seq_group_metadata_list
=
\
updated_seq_group_metadata_list
updated_seq_group_metadata_list
if
isinstance
(
updated_execute_model_req
.
previous_hidden_states
,
HiddenStates
):
updated_execute_model_req
.
previous_hidden_states
\
.
expand_with_bonus_tokens
(
seq_with_bonus_token_in_last_step
)
return
updated_execute_model_req
,
indices_of_original_sequence_groups
return
updated_execute_model_req
,
indices_of_original_sequence_groups
@
staticmethod
@
staticmethod
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
a3fce56b
...
@@ -147,6 +147,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -147,6 +147,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs
[
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
else
:
if
draft_worker_kwargs
[
"model_config"
].
hf_config
.
model_type
==
"eagle"
:
raise
NotImplementedError
(
"EAGLE does not support TP > 1 yet"
)
allow_zero_draft_token_step
=
False
allow_zero_draft_token_step
=
False
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
...
@@ -355,14 +360,34 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -355,14 +360,34 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req
)
execute_model_req
)
num_lookahead_slots
=
execute_model_req
.
num_lookahead_slots
num_lookahead_slots
=
execute_model_req
.
num_lookahead_slots
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
no_spec
=
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
or
disable_all_speculation
# Broadcast how many lookahead slots are scheduled for this step, and
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# whether all speculation is disabled, to all non-driver workers.
# This is required as if the number of draft model runs changes
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform them.
# communication to inform them.
# no_spec is used to signal non-driver worker about prefill vs decode
# stage. This is needed to ensure that order of execution of proposer
# and scorer is same in both driver and non-driver workers (i.e.,
# scorer -> proposer for prefill and proposer -> scorer in decode). This
# order is needed to support models like EAGLE that take scorer states
# as inputs.
broadcast_dict
=
dict
(
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
no_spec
=
no_spec
,
disable_all_speculation
=
disable_all_speculation
,
disable_all_speculation
=
disable_all_speculation
,
)
)
broadcast_tensor_dict
(
broadcast_dict
,
src
=
self
.
_driver_rank
)
broadcast_tensor_dict
(
broadcast_dict
,
src
=
self
.
_driver_rank
)
...
@@ -373,17 +398,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -373,17 +398,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_maybe_disable_speculative_tokens
(
self
.
_maybe_disable_speculative_tokens
(
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
# Speculative decoding is disabled in the following cases:
if
no_spec
:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
if
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
or
disable_all_speculation
:
return
self
.
_run_no_spec
(
execute_model_req
,
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all_speculation
)
skip_proposer
=
disable_all_speculation
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
...
@@ -464,8 +479,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -464,8 +479,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
updated, so they cannot enable spec decode in the rest decoding.
"""
"""
if
not
skip_proposer
:
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
assert
len
(
sampler_output
)
==
1
...
@@ -476,10 +489,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -476,10 +489,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
hidden_states
is
not
None
:
if
hidden_states
is
not
None
:
if
self
.
previous_hidden_states
is
None
:
if
self
.
previous_hidden_states
is
None
:
self
.
previous_hidden_states
=
HiddenStates
(
self
.
previous_hidden_states
=
HiddenStates
(
execute_model_req
.
seq_group_metadata_list
,
hidden_states
)
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
else
:
else
:
self
.
previous_hidden_states
.
update
(
self
.
previous_hidden_states
.
update
(
execute_model_req
.
seq_group_metadata_list
,
hidden_states
)
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
if
not
skip_proposer
:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
sampler_output
.
prefill_hidden_states
)
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
...
@@ -507,15 +530,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -507,15 +530,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return
False
return
False
num_lookahead_slots
=
data
[
"num_lookahead_slots"
]
num_lookahead_slots
=
data
[
"num_lookahead_slots"
]
# Even if num_lookahead_slots is zero, we want to run the proposer model
# In case of prefill, scorer_worker has to be run before proposer so
# as it may have KV.
# that the hidden states can be propagated to proposer when needed.
if
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
#
# We run the proposer once per lookahead slot. In the future we
should
# We run the proposer once per lookahead slot. In the future we
#
delegate how many times it runs to the proposer.
# should
delegate how many times it runs to the proposer.
for
_
in
range
(
max
(
num_lookahead_slots
,
1
)):
for
_
in
range
(
max
(
num_lookahead_slots
,
1
)):
self
.
proposer_worker
.
execute_model
()
self
.
proposer_worker
.
execute_model
()
if
not
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
self
.
scorer_worker
.
execute_model
()
return
True
return
True
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
...
@@ -546,6 +577,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -546,6 +577,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
"workers generate no tokens"
)
execute_model_req
.
previous_hidden_states
=
None
with
Timer
()
as
scoring_timer
:
with
Timer
()
as
scoring_timer
:
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
execute_model_req
,
...
@@ -651,10 +684,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -651,10 +684,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
second_last_token_hidden_states
=
hidden_states
[:,
-
2
]
# b x d
hidden_states
=
hidden_states
.
gather
(
1
,
index
).
squeeze
(
1
)
# b x d
hidden_states
=
hidden_states
.
gather
(
1
,
index
).
squeeze
(
1
)
# b x d
# Store hidden states from target model for subsequent decode step
# Store hidden states from target model for subsequent decode step
self
.
previous_hidden_states
=
HiddenStates
(
seq_group_metadata_list
,
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
)
hidden_states
,
seq_group_metadata_list
,
second_last_token_hidden_states
)
return
accepted_token_ids
,
logprobs
return
accepted_token_ids
,
logprobs
...
@@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
...
@@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
(
proposer_cache_block_size_bytes
+
scorer_cache_block_size_bytes
))
(
proposer_cache_block_size_bytes
+
scorer_cache_block_size_bytes
))
return
new_num_gpu_blocks
return
new_num_gpu_blocks
def
prepare_prefill_hidden_states
(
prefill_hidden_states
:
torch
.
Tensor
)
->
HiddenStates
:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
return
HiddenStates
(
prefill_hidden_states
.
roll
(
shifts
=
1
,
dims
=
0
))
if
prefill_hidden_states
is
not
None
else
None
vllm/transformers_utils/config.py
View file @
a3fce56b
...
@@ -11,10 +11,11 @@ from transformers.models.auto.modeling_auto import (
...
@@ -11,10 +11,11 @@ from transformers.models.auto.modeling_auto import (
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
InternVLChatConfig
,
JAISConfig
,
EAGLEConfig
,
InternVLChatConfig
,
MedusaConfig
,
MLPSpeculatorConfig
,
JAISConfig
,
MedusaConfig
,
MPTConfig
,
NemotronConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
RWConfig
,
UltravoxConfig
)
NemotronConfig
,
RWConfig
,
UltravoxConfig
)
if
VLLM_USE_MODELSCOPE
:
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
from
modelscope
import
AutoConfig
...
@@ -32,6 +33,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
...
@@ -32,6 +33,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"jais"
:
JAISConfig
,
"jais"
:
JAISConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"medusa"
:
MedusaConfig
,
"medusa"
:
MedusaConfig
,
"eagle"
:
EAGLEConfig
,
"internvl_chat"
:
InternVLChatConfig
,
"internvl_chat"
:
InternVLChatConfig
,
"nemotron"
:
NemotronConfig
,
"nemotron"
:
NemotronConfig
,
"ultravox"
:
UltravoxConfig
,
"ultravox"
:
UltravoxConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
a3fce56b
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
# `FalconConfig` class from the official HuggingFace transformers library.
...
@@ -20,6 +21,7 @@ __all__ = [
...
@@ -20,6 +21,7 @@ __all__ = [
"InternVLChatConfig"
,
"InternVLChatConfig"
,
"JAISConfig"
,
"JAISConfig"
,
"MedusaConfig"
,
"MedusaConfig"
,
"EAGLEConfig"
,
"MLPSpeculatorConfig"
,
"MLPSpeculatorConfig"
,
"NemotronConfig"
,
"NemotronConfig"
,
"UltravoxConfig"
,
"UltravoxConfig"
,
...
...
vllm/transformers_utils/configs/eagle.py
0 → 100644
View file @
a3fce56b
import
os
from
typing
import
Optional
,
Union
from
transformers
import
AutoConfig
,
PretrainedConfig
class
EAGLEConfig
(
PretrainedConfig
):
model_type
=
"eagle"
def
__init__
(
self
,
model
:
Union
[
PretrainedConfig
,
dict
,
None
]
=
None
,
truncated_vocab_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
model_config
=
None
if
model
is
None
else
(
AutoConfig
.
for_model
(
**
model
)
if
isinstance
(
model
,
dict
)
else
model
)
for
k
,
v
in
kwargs
.
items
():
if
k
!=
"architectures"
and
k
!=
"model_type"
and
hasattr
(
model_config
,
k
):
setattr
(
model_config
,
k
,
v
)
self
.
model
=
model_config
if
self
.
model
is
None
:
self
.
truncated_vocab_size
=
None
else
:
self
.
truncated_vocab_size
=
self
.
model
.
vocab_size
if
\
truncated_vocab_size
is
None
else
truncated_vocab_size
if
"architectures"
not
in
kwargs
:
kwargs
[
"architectures"
]
=
[
"EAGLEModel"
]
super
().
__init__
(
**
kwargs
)
if
self
.
model
is
not
None
:
for
k
,
v
in
self
.
model
.
to_dict
().
items
():
if
not
hasattr
(
self
,
k
):
setattr
(
self
,
k
,
v
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
,
)
->
"EAGLEConfig"
:
config_dict
,
kwargs
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
vllm/worker/model_runner.py
View file @
a3fce56b
import
dataclasses
import
dataclasses
import
gc
import
gc
import
inspect
import
itertools
import
itertools
import
time
import
time
import
warnings
import
warnings
...
@@ -1192,6 +1193,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1192,6 +1193,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
previous_hidden_states
=
None
if
"previous_hidden_states"
in
inspect
.
signature
(
self
.
model
.
forward
).
parameters
:
previous_hidden_states
=
torch
.
empty
(
[
max_batch_size
,
self
.
model_config
.
get_hidden_size
()],
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
intermediate_inputs
=
None
intermediate_inputs
=
None
if
not
get_pp_group
().
is_first_rank
:
if
not
get_pp_group
().
is_first_rank
:
intermediate_inputs
=
self
.
model
.
make_empty_intermediate_tensors
(
intermediate_inputs
=
self
.
model
.
make_empty_intermediate_tensors
(
...
@@ -1264,6 +1277,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1264,6 +1277,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"stream"
:
"stream"
:
graph_capture_context
.
stream
graph_capture_context
.
stream
}
}
if
previous_hidden_states
is
not
None
:
capture_inputs
[
"previous_hidden_states"
]
=
previous_hidden_states
[:
batch_size
]
if
self
.
has_seqlen_agnostic
:
if
self
.
has_seqlen_agnostic
:
# Only used by Mamba-based models CUDA graph atm (Jamba)
# Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs
.
update
({
capture_inputs
.
update
({
...
@@ -1462,6 +1480,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1462,6 +1480,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
model_input
.
is_prompt
:
if
model_input
.
is_prompt
:
hidden_states
=
hidden_or_intermediate_states
.
index_select
(
hidden_states
=
hidden_or_intermediate_states
.
index_select
(
0
,
indices
)
0
,
indices
)
output
.
prefill_hidden_states
=
hidden_or_intermediate_states
elif
decode_meta
.
use_cuda_graph
:
elif
decode_meta
.
use_cuda_graph
:
hidden_states
=
hidden_or_intermediate_states
[:
len
(
indices
)]
hidden_states
=
hidden_or_intermediate_states
[:
len
(
indices
)]
else
:
else
:
...
@@ -1510,11 +1529,11 @@ class CUDAGraphRunner:
...
@@ -1510,11 +1529,11 @@ class CUDAGraphRunner:
# Note one iteration is not enough for torch.jit.script
# Note one iteration is not enough for torch.jit.script
for
_
in
range
(
_NUM_WARMUP_ITERS
):
for
_
in
range
(
_NUM_WARMUP_ITERS
):
self
.
model
(
self
.
model
(
input_ids
,
input_ids
=
input_ids
,
positions
,
positions
=
positions
,
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
,
attn_metadata
=
attn_metadata
,
intermediate_inputs
,
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
**
kwargs
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -1523,11 +1542,11 @@ class CUDAGraphRunner:
...
@@ -1523,11 +1542,11 @@ class CUDAGraphRunner:
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
output_hidden_or_intermediate_states
=
self
.
model
(
output_hidden_or_intermediate_states
=
self
.
model
(
input_ids
,
input_ids
=
input_ids
,
positions
,
positions
=
positions
,
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
,
attn_metadata
=
attn_metadata
,
intermediate_inputs
,
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
**
kwargs
,
)
)
if
hidden_or_intermediate_states
is
not
None
:
if
hidden_or_intermediate_states
is
not
None
:
...
@@ -1588,6 +1607,11 @@ class CUDAGraphRunner:
...
@@ -1588,6 +1607,11 @@ class CUDAGraphRunner:
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
self
.
model
.
copy_inputs_before_cuda_graphs
(
self
.
input_buffers
,
self
.
model
.
copy_inputs_before_cuda_graphs
(
self
.
input_buffers
,
**
kwargs
)
**
kwargs
)
if
"previous_hidden_states"
in
self
.
input_buffers
:
self
.
input_buffers
[
"previous_hidden_states"
].
copy_
(
kwargs
[
"previous_hidden_states"
],
non_blocking
=
True
)
if
intermediate_tensors
is
not
None
:
if
intermediate_tensors
is
not
None
:
for
key
in
intermediate_tensors
.
tensors
:
for
key
in
intermediate_tensors
.
tensors
:
if
key
!=
"model_execute_time"
and
key
!=
"model_forward_time"
:
if
key
!=
"model_execute_time"
and
key
!=
"model_forward_time"
:
...
...
vllm/worker/multi_step_worker.py
View file @
a3fce56b
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
...
@@ -43,7 +45,7 @@ class MultiStepWorker(Worker):
...
@@ -43,7 +45,7 @@ class MultiStepWorker(Worker):
def
_get_driver_input_and_broadcast
(
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Tuple
[
BroadcastableModelInput
,
WorkerInput
]:
)
->
Tuple
[
BroadcastableModelInput
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]
]:
"""
"""
Get the driver input and broadcast it to other workers.
Get the driver input and broadcast it to other workers.
"""
"""
...
@@ -85,7 +87,9 @@ class MultiStepWorker(Worker):
...
@@ -85,7 +87,9 @@ class MultiStepWorker(Worker):
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
return
model_input
,
worker_input
# Retuning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return
model_input
,
worker_input
,
{}
def
_prepare_last_sampled_token_ids_for_tp_workers
(
def
_prepare_last_sampled_token_ids_for_tp_workers
(
self
,
self
,
...
@@ -130,7 +134,8 @@ class MultiStepWorker(Worker):
...
@@ -130,7 +134,8 @@ class MultiStepWorker(Worker):
def
prepare_input
(
def
prepare_input
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
Optional
[
Tuple
[
StatefulModelInput
,
WorkerInput
]]:
)
->
Optional
[
Tuple
[
StatefulModelInput
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]]:
"""
"""
Depending on the current state of the request and multi step worker,
Depending on the current state of the request and multi step worker,
this method may skip the normal _prepare_model_input and
this method may skip the normal _prepare_model_input and
...
@@ -148,8 +153,8 @@ class MultiStepWorker(Worker):
...
@@ -148,8 +153,8 @@ class MultiStepWorker(Worker):
return
None
return
None
virtual_engine
=
execute_model_req
.
virtual_engine
virtual_engine
=
execute_model_req
.
virtual_engine
model_input
,
worker_input
=
self
.
_get_driver_input_and_broadcast
(
(
model_input
,
worker_input
,
execute_model_req
)
kwargs
)
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
)
assert
isinstance
(
model_input
,
StatefulModelInput
)
assert
isinstance
(
model_input
,
StatefulModelInput
)
if
execute_model_req
.
is_first_multi_step
:
if
execute_model_req
.
is_first_multi_step
:
# cache the worker input and model input for the next steps
# cache the worker input and model input for the next steps
...
@@ -162,7 +167,7 @@ class MultiStepWorker(Worker):
...
@@ -162,7 +167,7 @@ class MultiStepWorker(Worker):
# loop
# loop
if
broadcast_data
is
None
:
if
broadcast_data
is
None
:
return
None
return
None
model_input
,
worker_input
=
broadcast_data
model_input
,
worker_input
,
kwargs
=
broadcast_data
assert
isinstance
(
model_input
,
StatefulModelInput
)
assert
isinstance
(
model_input
,
StatefulModelInput
)
virtual_engine
=
worker_input
.
virtual_engine
virtual_engine
=
worker_input
.
virtual_engine
if
model_input
.
is_first_multi_step
:
if
model_input
.
is_first_multi_step
:
...
@@ -186,4 +191,4 @@ class MultiStepWorker(Worker):
...
@@ -186,4 +191,4 @@ class MultiStepWorker(Worker):
assert
model_input
is
not
None
assert
model_input
is
not
None
assert
worker_input
is
not
None
assert
worker_input
is
not
None
return
model_input
,
worker_input
return
model_input
,
worker_input
,
kwargs
vllm/worker/worker.py
View file @
a3fce56b
...
@@ -86,7 +86,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -86,7 +86,7 @@ class Worker(LocalOrDistributedWorkerBase):
or
(
speculative_config
.
draft_model_config
.
model
==
or
(
speculative_config
.
draft_model_config
.
model
==
model_config
.
model
)
\
model_config
.
model
)
\
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
not
in
[
"medusa"
,
"mlp_speculator"
])
\
not
in
[
"medusa"
,
"mlp_speculator"
,
"eagle"
])
\
else
{
"return_hidden_states"
:
True
}
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
...
...
vllm/worker/worker_base.py
View file @
a3fce56b
...
@@ -222,7 +222,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -222,7 +222,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
raise
NotImplementedError
raise
NotImplementedError
def
_get_worker_input_from_broadcast
(
def
_get_worker_input_from_broadcast
(
self
)
->
Optional
[
Tuple
[
BroadcastableModelInput
,
WorkerInput
]]:
self
)
->
Optional
[
Tuple
[
BroadcastableModelInput
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]]:
""" Get the worker input from the broadcasted tensor dict. """
""" Get the worker input from the broadcasted tensor dict. """
assert
self
.
do_metadata_broadcast
assert
self
.
do_metadata_broadcast
assert
not
self
.
is_driver_worker
assert
not
self
.
is_driver_worker
...
@@ -235,11 +237,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -235,11 +237,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
broadcast_data
))
broadcast_data
))
return
model_input
,
worker_input
kwargs
=
extract_previous_hidden_states
(
broadcast_data
)
return
model_input
,
worker_input
,
kwargs
def
_get_driver_input_and_broadcast
(
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Tuple
[
BroadcastableModelInput
,
WorkerInput
]:
)
->
Tuple
[
BroadcastableModelInput
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]
]:
""" Get the driver input and broadcast it to other workers. """
""" Get the driver input and broadcast it to other workers. """
assert
self
.
is_driver_worker
assert
self
.
is_driver_worker
...
@@ -251,17 +255,21 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -251,17 +255,21 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
))
kwargs
=
extract_previous_hidden_states
(
execute_model_req
)
if
self
.
do_metadata_broadcast
:
if
self
.
do_metadata_broadcast
:
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_data
.
update
(
kwargs
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
return
model_input
,
worker_input
return
model_input
,
worker_input
,
kwargs
def
prepare_input
(
def
prepare_input
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
Optional
[
Tuple
[
BroadcastableModelInput
,
WorkerInput
]]:
)
->
Optional
[
Tuple
[
BroadcastableModelInput
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]]:
"""
"""
Prepare the inputs to ModelRunner and workers.
Prepare the inputs to ModelRunner and workers.
"""
"""
...
@@ -291,7 +299,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -291,7 +299,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if
inputs
is
None
:
if
inputs
is
None
:
return
None
return
None
model_input
,
worker_input
=
inputs
model_input
,
worker_input
,
kwargs
=
inputs
num_steps
=
worker_input
.
num_steps
num_steps
=
worker_input
.
num_steps
self
.
execute_worker
(
worker_input
)
self
.
execute_worker
(
worker_input
)
...
@@ -312,9 +320,14 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -312,9 +320,14 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"model_execute_time"
,
torch
.
tensor
(
0
)).
item
()
"model_execute_time"
,
torch
.
tensor
(
0
)).
item
()
output
=
self
.
model_runner
.
execute_model
(
output
=
self
.
model_runner
.
execute_model
(
model_input
,
self
.
kv_cache
[
worker_input
.
virtual_engine
]
model_input
=
model_input
,
if
self
.
kv_cache
is
not
None
else
None
,
intermediate_tensors
,
kv_caches
=
self
.
kv_cache
[
worker_input
.
virtual_engine
]
num_steps
)
if
self
.
kv_cache
is
not
None
else
None
,
intermediate_tensors
=
intermediate_tensors
,
num_steps
=
num_steps
,
**
kwargs
,
)
model_execute_time
=
time
.
perf_counter
()
-
start_time
model_execute_time
=
time
.
perf_counter
()
-
start_time
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
# output is IntermediateTensors
# output is IntermediateTensors
...
@@ -360,9 +373,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -360,9 +373,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if
worker_input
.
num_seq_groups
==
0
:
if
worker_input
.
num_seq_groups
==
0
:
return
[]
return
[]
kwargs
=
extract_previous_hidden_states
(
execute_model_req
)
return
self
.
model_runner
.
execute_model
(
return
self
.
model_runner
.
execute_model
(
model_input
,
self
.
kv_cache
[
worker_input
.
virtual_engine
]
model_input
=
model_input
,
if
self
.
kv_cache
is
not
None
else
None
,
intermediate_tensors
)
kv_caches
=
self
.
kv_cache
[
worker_input
.
virtual_engine
]
if
self
.
kv_cache
is
not
None
else
None
,
intermediate_tensors
=
intermediate_tensors
,
**
kwargs
,
)
class
WorkerWrapperBase
:
class
WorkerWrapperBase
:
...
@@ -439,3 +458,23 @@ class WorkerWrapperBase:
...
@@ -439,3 +458,23 @@ class WorkerWrapperBase:
"This might cause deadlock in distributed execution."
)
"This might cause deadlock in distributed execution."
)
logger
.
exception
(
msg
)
logger
.
exception
(
msg
)
raise
e
raise
e
def
extract_previous_hidden_states
(
data
:
Union
[
ExecuteModelRequest
,
Dict
[
str
,
torch
.
Tensor
]])
->
\
Dict
[
str
,
torch
.
Tensor
]:
"""If data contains previous_hidden_states, extract it. This returns a dict
which can be used directly as additional kwargs in any following
execute_model calls. This is used in draft models like EAGLE."""
output
=
{}
# When called from non-driver worker, data is dict but when called from
# driver worker, data is ExecuteModelRequest.
if
isinstance
(
data
,
dict
):
if
"previous_hidden_states"
in
data
:
output
[
"previous_hidden_states"
]
=
data
[
"previous_hidden_states"
]
elif
data
.
previous_hidden_states
is
not
None
:
output
[
"previous_hidden_states"
]
=
data
.
previous_hidden_states
\
.
hidden_states
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment