Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc8ad684
Unverified
Commit
bc8ad684
authored
May 03, 2024
by
Cody Yu
Committed by
GitHub
May 03, 2024
Browse files
[Misc][Refactor] Introduce ExecuteModelData (#4540)
parent
344bf7cd
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
332 additions
and
486 deletions
+332
-486
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+48
-50
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+29
-35
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+46
-49
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+4
-46
tests/worker/test_swap.py
tests/worker/test_swap.py
+20
-10
vllm/core/scheduler.py
vllm/core/scheduler.py
+4
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+10
-6
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+8
-4
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+10
-27
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+7
-15
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+7
-26
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+14
-19
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+5
-14
vllm/sequence.py
vllm/sequence.py
+31
-1
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+10
-20
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+3
-12
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+21
-33
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+21
-41
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+23
-67
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+11
-11
No files found.
tests/spec_decode/test_multi_step_worker.py
View file @
bc8ad684
...
...
@@ -5,13 +5,12 @@ import pytest
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
create_execute_model_data
,
create_seq_group_metadata_from_prompts
,
create_worker
,
patch_execute_model_with_seeds
,
zero_kv_cache
)
...
...
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
final_prompt_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
multi_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
multi_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
single_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
**
multi_step_execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
multi_step_seq_group
),
sample_len
=
num_steps
)
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
single_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
worker
.
cache_engine
)
set_random_seed
(
seed
)
expected_output
=
worker
.
execute_model
(
**
single_step_execute_model_data
.
to_dict
(),
)[
0
]
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
single_step_seq_group
))[
0
]
actual_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
actual_output
...
...
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
continuations
=
[[
1
]
for
_
in
prompts
]
execute_model_data
=
create_execute_model_data
(
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
,
)
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
**
execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
)
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
...
...
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
for
_
in
multi_step_output
:
execute_model_data
=
create_execute_model_data
(
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
)
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
))
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
...
...
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
)
for
_
in
range
(
k
)
],
True
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
proposal_len
=
k
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
max_proposal_len
=
prompt_len
+
k
-
1
,
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
proposal_len
=
k
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
)
for
_
in
range
(
k
)
],
True
execute_model_data
,
_
,
_
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prompt_len
=
prompt_len
,
prev_output_token_len
=
prev_output_token_len
,
)
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
proposal_len
=
k
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
View file @
bc8ad684
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
(
create_execute_model_data
,
create_seq_group_metadata_from_prompts
,
create_worker
)
from
.utils
import
create_seq_group_metadata_from_prompts
,
create_worker
def
test_ngram_algo_correctness_for_single_no_match
():
...
...
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
bc8ad684
...
...
@@ -7,7 +7,7 @@ import torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
...
...
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
from
.utils
import
(
ExecuteModelData
,
create_batch
,
create_sampler_output_list
,
mock_worker
)
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
...
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
assert
len
(
call_args_list
)
==
1
for
args
,
_
in
call_args_list
:
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
actual_k
)
=
args
actual_execute_model_data
=
ExecuteModelData
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
assert
actual_execute_model_data
==
execute_model_data
assert
actual_k
==
k
actual_execute_model_data
=
args
[
0
]
assert
actual_execute_model_data
==
execute_model_req
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
...
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
...
...
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
seen_contexts
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
for
args
,
kwargs
in
call_args_list
:
target_execute_model_data
=
ExecuteModelData
.
from_dict
(
kwargs
)
for
_
,
kwargs
in
call_args_list
:
seq_group_metadata_list
=
kwargs
[
"execute_model_req"
].
seq_group_metadata_list
assert
len
(
target_execute_model_data
.
seq_group_metadata_list
)
==
(
k
+
1
)
*
batch_size
for
seq_group_metadata
in
(
target_execute_model_data
.
seq_group_metadata_list
):
assert
len
(
seq_group_metadata_list
)
==
(
k
+
1
)
*
batch_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
...
...
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
...
...
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
...
...
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
...
...
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler
.
return_value
=
rejection_sampler_output
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
...
...
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
execute_model_data
.
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
...
...
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
...
...
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
mock_rejsample_metrics
)
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
call_args_list
=
(
...
...
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
...
...
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
skip_global_cleanup
...
...
tests/spec_decode/utils.py
View file @
bc8ad684
from
dataclasses
import
dataclass
,
fields
from
itertools
import
count
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
unittest.mock
import
MagicMock
...
...
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
from
vllm.worker.worker
import
Worker
@
dataclass
class
ExecuteModelData
:
"""Helper data structure which facilitates cleaner tests.
"""
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
def
to_dict
(
self
):
return
dict
(
(
field
.
name
,
getattr
(
self
,
field
.
name
))
for
field
in
fields
(
self
))
@
classmethod
def
from_dict
(
cls
,
d
):
cleaned
=
dict
((
field
.
name
,
d
[
field
.
name
])
for
field
in
fields
(
cls
))
return
cls
(
**
cleaned
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
def
create_execute_model_data
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
)
->
ExecuteModelData
:
if
blocks_to_swap_in
is
None
:
blocks_to_swap_in
=
{}
if
blocks_to_swap_out
is
None
:
blocks_to_swap_out
=
{}
if
blocks_to_copy
is
None
:
blocks_to_copy
=
{}
return
ExecuteModelData
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
def
mock_worker
(
cls
=
None
,
vocab_size
:
int
=
30_000
,
max_model_len
:
int
=
2048
,
...
...
@@ -258,8 +217,7 @@ def create_batch(batch_size,
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
execute_model_data
=
create_execute_model_data
(
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
),
)
return
execute_model_data
,
prompts
,
prev_output_tokens
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
tests/worker/test_swap.py
View file @
bc8ad684
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.worker
import
Worker
...
...
@@ -54,10 +55,14 @@ def test_swap() -> None:
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
{},
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{})
blocks_to_copy
=
{},
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
...
...
@@ -66,14 +71,19 @@ def test_swap() -> None:
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
blocks_to_swap_in
=
{
19
:
45
,
67
:
23
,
12
:
78
,
40
:
99
,
1
:
71
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
{},
blocks_to_copy
=
{})
execute_model_req
.
blocks_to_swap_out
=
{}
execute_model_req
.
blocks_to_swap_in
=
{
19
:
45
,
67
:
23
,
12
:
78
,
40
:
99
,
1
:
71
}
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_in
.
items
():
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
.
items
():
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
vllm/core/scheduler.py
View file @
bc8ad684
...
...
@@ -128,6 +128,8 @@ class SchedulerOutputs:
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
# The number of requests in the running queue
running_queue_size
:
int
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
...
...
@@ -797,6 +799,7 @@ class Scheduler:
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
def
_schedule_chunked_prefill
(
self
):
...
...
@@ -883,6 +886,7 @@ class Scheduler:
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
...
...
vllm/engine/async_llm_engine.py
View file @
bc8ad684
...
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
...
...
@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
await
self
.
model_executor
.
execute_model_async
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
execute_model_req
)
else
:
output
=
[]
...
...
vllm/engine/llm_engine.py
View file @
bc8ad684
...
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
...
...
@@ -583,12 +583,16 @@ class LLMEngine:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
not
scheduler_outputs
.
is_empty
():
output
=
self
.
model_executor
.
e
xecute
_m
odel
(
execute_model_req
=
E
xecute
M
odel
Request
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
else
:
output
=
[]
...
...
vllm/executor/cpu_executor.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
import
torch
...
...
@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
...
...
@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
logger
.
info
(
"# CPU blocks: %d"
,
num_gpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
@@ -105,18 +97,9 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
return
output
async
def
check_health_async
(
self
)
->
None
:
...
...
vllm/executor/executor_base.py
View file @
bc8ad684
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
class
ExecutorBase
(
ABC
):
...
...
@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
raise
NotImplementedError
@
abstractmethod
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Executes at least one model step on the given sequences."""
raise
NotImplementedError
...
...
@@ -108,12 +105,7 @@ class ExecutorAsyncBase(ExecutorBase):
@
abstractmethod
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
bc8ad684
...
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -118,19 +118,8 @@ class GPUExecutor(ExecutorBase):
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
,
)
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
return
output
vllm/executor/neuron_executor.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
...
...
@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
"""
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
and
blocks_to_copy
==
{}),
(
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
assert
(
execute_model_req
.
blocks_to_swap_in
==
{}
and
execute_model_req
.
blocks_to_swap_out
==
{}
and
execute_model_req
.
blocks_to_copy
==
{}),
(
"Cache operations are not supported for Neuron backend."
)
assert
num_lookahead_slots
==
0
,
(
assert
execute_model_req
.
num_lookahead_slots
==
0
,
(
"lookahead not supported for Neuron backend."
)
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
)
execute_model_req
.
seq_group_metadata_list
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
)
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
)
return
output
async
def
check_health_async
(
self
)
->
None
:
...
...
vllm/executor/ray_gpu_executor.py
View file @
bc8ad684
...
...
@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
...
...
@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
=
0
)
->
List
[
SamplerOutput
]:
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
"num_lookahead_slots"
:
num_lookahead_slots
,
},
driver_kwargs
=
{
"execute_model_req"
:
execute_model_req
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
...
...
vllm/sequence.py
View file @
bc8ad684
"""Sequence and its related classes."""
import
copy
import
enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Union
from
vllm.block
import
LogicalTokenBlock
...
...
@@ -734,3 +734,33 @@ class SamplerOutput:
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@
dataclass
class
ExecuteModelRequest
:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
field
(
default_factory
=
dict
)
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
=
0
# The number of requests in the running queue.
running_queue_size
:
int
=
0
def
clone
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
"ExecuteModelRequest"
:
"""Clone the request with a new sequence group metadata list."""
return
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
self
.
blocks_to_swap_in
.
copy
(),
blocks_to_swap_out
=
self
.
blocks_to_swap_out
.
copy
(),
blocks_to_copy
=
self
.
blocks_to_copy
.
copy
(),
num_lookahead_slots
=
self
.
num_lookahead_slots
,
running_queue_size
=
self
.
running_queue_size
,
)
vllm/spec_decode/batch_expansion.py
View file @
bc8ad684
from
itertools
import
chain
,
count
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
from
typing
import
Iterator
,
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
...
...
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
"""Score the proposed tokens via the scorer model.
...
...
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.
Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
...
...
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
seq_group_metadata_list
),
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
k
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
return
SpeculativeScores
(
...
...
vllm/spec_decode/interfaces.py
View file @
bc8ad684
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
import
torch
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
@
dataclass
...
...
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
@
abstractmethod
def
get_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
raise
NotImplementedError
...
...
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
@
abstractmethod
def
score_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
raise
NotImplementedError
vllm/spec_decode/multi_step_worker.py
View file @
bc8ad684
import
copy
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
...
...
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
...
...
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
self
.
_raise_if_unsupported
(
execute_model_req
)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list
=
self
.
_shallow_copy_inputs
(
seq_group_metadata_list
)
execute_model_req
.
seq_group_metadata_list
)
copied_execute_model_req
=
execute_model_req
.
clone
(
copied_seq_group_metadata_list
)
# Assert enough KV space for sample_len tokens per sequence.
self
.
_assert_enough_kv_space
(
seq_group_metadata_list
,
sample_len
)
self
.
_assert_enough_kv_space
(
execute_model_req
.
seq_group_metadata_list
,
sample_len
)
# Run model sample_len times.
model_outputs
=
[]
for
_
in
range
(
sample_len
):
model_output
=
super
().
execute_model
(
seq_group_metadata_list
=
copied_seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
execute_model_req
=
copied_execute_model_req
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
...
...
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
def
get_spec_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
...
...
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
def
_raise_if_unsupported
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MultiStepWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
vllm/spec_decode/ngram_worker.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
...
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# NGram don't need gpu sampler
pass
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
None
:
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
pass
...
...
@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def
sampler_output
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
"""NGram match algo to pick proposal candidate. Returns the list of
...
...
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
)
self
.
_raise_if_unsupported
(
execute_model_req
)
arr
=
[]
has_spec_out
=
False
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
:
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
...
...
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
indices
=
token_ids
.
unsqueeze
(
2
)
token_probs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_logprobs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
for
i
in
range
(
len
(
seq_group_metadata_list
)):
for
i
in
range
(
len
(
execute_model_req
.
seq_group_metadata_list
)):
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
...
...
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def
get_spec_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
def
_raise_if_unsupported
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"NGramWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"NGramWorker does not support beam search."
)
vllm/spec_decode/spec_decode_worker.py
View file @
bc8ad684
from
functools
import
cached_property
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
...
...
@@ -190,68 +191,36 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""
assert
seq_group_metadata_list
is
not
None
,
(
assert
execute_model_req
.
seq_group_metadata_list
is
not
None
,
(
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
#logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
# num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if
num_lookahead_slots
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
execute_model_req
)
return
self
.
_run_speculative_decoding_step
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_lookahead_slots
,
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
"""
#logger.info("run proposer worker no spec")
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
#logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
...
...
@@ -265,12 +234,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
def
_run_speculative_decoding_step
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
...
...
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
#logger.info("get spec proposals")
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
#logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
,
execute_model_req
,
proposals
,
)
#logger.info("verify proposals")
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
k
=
k
)
k
=
execute_model_req
.
num_lookahead_slots
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
...
...
vllm/spec_decode/top1_proposer.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
...
...
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):
def
get_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
proposal_len
=
execute_model_req
.
num_lookahead_slots
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
...
...
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_outpu
t
(
nonzero_execute_model_req
=
ExecuteModelReques
t
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
proposal_len
,
)
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_output
(
execute_model_req
=
nonzero_execute_model_req
,
sample_len
=
proposal_len
,
)
else
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment