Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc8ad684
Unverified
Commit
bc8ad684
authored
May 03, 2024
by
Cody Yu
Committed by
GitHub
May 03, 2024
Browse files
[Misc][Refactor] Introduce ExecuteModelData (#4540)
parent
344bf7cd
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
332 additions
and
486 deletions
+332
-486
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+48
-50
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+29
-35
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+46
-49
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+4
-46
tests/worker/test_swap.py
tests/worker/test_swap.py
+20
-10
vllm/core/scheduler.py
vllm/core/scheduler.py
+4
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+10
-6
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+8
-4
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+10
-27
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+7
-15
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+7
-26
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+14
-19
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+5
-14
vllm/sequence.py
vllm/sequence.py
+31
-1
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+10
-20
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+3
-12
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+21
-33
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+21
-41
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+23
-67
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+11
-11
No files found.
tests/spec_decode/test_multi_step_worker.py
View file @
bc8ad684
...
@@ -5,13 +5,12 @@ import pytest
...
@@ -5,13 +5,12 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
create_execute_model_data
,
create_seq_group_metadata_from_prompts
,
create_worker
,
create_seq_group_metadata_from_prompts
,
create_worker
,
patch_execute_model_with_seeds
,
zero_kv_cache
)
patch_execute_model_with_seeds
,
zero_kv_cache
)
...
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
...
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
final_prompt_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
final_prompt_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
multi_step_execute_model_data
=
create_execute_model_data
(
multi_step_seq_group
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
))
single_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
))
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
**
multi_step_execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
multi_step_seq_group
),
sample_len
=
num_steps
)
assert
len
(
actual_output
)
==
num_steps
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
actual_output
=
actual_output
[
0
]
single_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
expected_output
=
worker
.
execute_model
(
expected_output
=
worker
.
execute_model
(
**
single_step_execute_model_data
.
to_dict
(),
)[
0
]
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
single_step_seq_group
))[
0
]
actual_token_ids
=
[
actual_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
actual_output
output
.
samples
[
0
].
output_token
for
output
in
actual_output
...
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
...
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
continuations
=
[[
1
]
for
_
in
prompts
]
continuations
=
[[
1
]
for
_
in
prompts
]
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
continuations
=
continuations
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
),
)
# Run multi-step.
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
**
execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
)
# Run single-step repeatedly.
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
...
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
...
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
for
_
in
multi_step_output
:
for
_
in
multi_step_output
:
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
continuations
=
continuations
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
))
single_step_output
.
extend
(
single_step_output
.
extend
(
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
))
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
...
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
...
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
)
for
_
in
range
(
k
)
)
for
_
in
range
(
k
)
],
True
],
True
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
...
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
max_proposal_len
=
prompt_len
+
k
-
1
,
max_proposal_len
=
prompt_len
+
k
-
1
,
)
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
prompt_len
=
prompt_len
)
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
...
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
)
for
_
in
range
(
k
)
)
for
_
in
range
(
k
)
],
True
],
True
execute_model_data
,
_
,
_
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
k
,
prompt_len
=
prompt_len
,
prompt_len
=
prompt_len
,
prev_output_token_len
=
prev_output_token_len
,
prev_output_token_len
=
prev_output_token_len
,
)
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
View file @
bc8ad684
import
torch
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
(
create_execute_model_data
,
from
.utils
import
create_seq_group_metadata_from_prompts
,
create_worker
create_seq_group_metadata_from_prompts
,
create_worker
)
def
test_ngram_algo_correctness_for_single_no_match
():
def
test_ngram_algo_correctness_for_single_no_match
():
...
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
...
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
proposal_len
=
5
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
**
ngram_sampler_output_data
.
to_dict
(),
num_lookahead_slots
=
proposal_len
),
)
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposal_len
=
5
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
**
ngram_sampler_output_data
.
to_dict
(),
num_lookahead_slots
=
proposal_len
),
)
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
...
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposal_len
=
5
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
))
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
**
ngram_sampler_output_data
.
to_dict
(),
num_lookahead_slots
=
proposal_len
),
)
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
bc8ad684
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
SpecDecodeWorkerMetrics
)
...
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
...
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
split_num_cache_blocks_evenly
)
from
.utils
import
(
ExecuteModelData
,
create_batch
,
create_sampler_output_list
,
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
mock_worker
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
...
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
exception_secret
=
'artificial stop'
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
for
args
,
_
in
call_args_list
:
for
args
,
_
in
call_args_list
:
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
actual_execute_model_data
=
args
[
0
]
blocks_to_copy
,
actual_k
)
=
args
assert
actual_execute_model_data
==
execute_model_req
actual_execute_model_data
=
ExecuteModelData
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
assert
actual_execute_model_data
==
execute_model_data
assert
actual_k
==
k
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
)
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
...
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
seen_contexts
=
[]
seen_contexts
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
for
args
,
kwargs
in
call_args_list
:
for
_
,
kwargs
in
call_args_list
:
target_execute_model_data
=
ExecuteModelData
.
from_dict
(
kwargs
)
seq_group_metadata_list
=
kwargs
[
"execute_model_req"
].
seq_group_metadata_list
assert
len
(
target_execute_model_data
.
seq_group_metadata_list
)
==
(
assert
len
(
seq_group_metadata_list
)
==
(
k
+
1
)
*
batch_size
k
+
1
)
*
batch_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
(
target_execute_model_data
.
seq_group_metadata_list
):
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
...
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
...
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler
.
return_value
=
rejection_sampler_output
rejection_sampler
.
return_value
=
rejection_sampler_output
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
expected_output
=
create_sampler_output_list
(
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
...
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
seq_ids
=
[
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
execute_model_data
.
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
]
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
...
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
mock_rejsample_metrics
)
mock_rejsample_metrics
)
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
call_args_list
=
(
call_args_list
=
(
...
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
...
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
prev_output_token_len
=
0
)
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
...
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
...
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
prev_output_token_len
=
0
)
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
...
...
tests/spec_decode/utils.py
View file @
bc8ad684
from
dataclasses
import
dataclass
,
fields
from
itertools
import
count
from
itertools
import
count
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
...
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
...
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
@
dataclass
class
ExecuteModelData
:
"""Helper data structure which facilitates cleaner tests.
"""
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
def
to_dict
(
self
):
return
dict
(
(
field
.
name
,
getattr
(
self
,
field
.
name
))
for
field
in
fields
(
self
))
@
classmethod
def
from_dict
(
cls
,
d
):
cleaned
=
dict
((
field
.
name
,
d
[
field
.
name
])
for
field
in
fields
(
cls
))
return
cls
(
**
cleaned
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
return
(
seq_len
+
block_size
-
1
)
//
block_size
def
create_execute_model_data
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
)
->
ExecuteModelData
:
if
blocks_to_swap_in
is
None
:
blocks_to_swap_in
=
{}
if
blocks_to_swap_out
is
None
:
blocks_to_swap_out
=
{}
if
blocks_to_copy
is
None
:
blocks_to_copy
=
{}
return
ExecuteModelData
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
def
mock_worker
(
cls
=
None
,
def
mock_worker
(
cls
=
None
,
vocab_size
:
int
=
30_000
,
vocab_size
:
int
=
30_000
,
max_model_len
:
int
=
2048
,
max_model_len
:
int
=
2048
,
...
@@ -258,8 +217,7 @@ def create_batch(batch_size,
...
@@ -258,8 +217,7 @@ def create_batch(batch_size,
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
]
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
prev_output_tokens
,
seq_ids
),
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
return
execute_model_data
,
prompts
,
prev_output_tokens
tests/worker/test_swap.py
View file @
bc8ad684
import
torch
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -54,10 +55,14 @@ def test_swap() -> None:
...
@@ -54,10 +55,14 @@ def test_swap() -> None:
# Test swap out.
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
execute_model_req
=
ExecuteModelRequest
(
blocks_to_swap_in
=
{},
seq_group_metadata_list
=
[],
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_in
=
{},
blocks_to_copy
=
{})
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{},
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
...
@@ -66,14 +71,19 @@ def test_swap() -> None:
...
@@ -66,14 +71,19 @@ def test_swap() -> None:
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
# Test swap in.
blocks_to_swap_in
=
{
19
:
45
,
67
:
23
,
12
:
78
,
40
:
99
,
1
:
71
}
execute_model_req
.
blocks_to_swap_out
=
{}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
execute_model_req
.
blocks_to_swap_in
=
{
blocks_to_swap_in
=
blocks_to_swap_in
,
19
:
45
,
blocks_to_swap_out
=
{},
67
:
23
,
blocks_to_copy
=
{})
12
:
78
,
40
:
99
,
1
:
71
}
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_in
.
items
():
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
.
items
():
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
vllm/core/scheduler.py
View file @
bc8ad684
...
@@ -128,6 +128,8 @@ class SchedulerOutputs:
...
@@ -128,6 +128,8 @@ class SchedulerOutputs:
ignored_seq_groups
:
List
[
SequenceGroup
]
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# The number of requests in the running queue
running_queue_size
:
int
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
...
@@ -797,6 +799,7 @@ class Scheduler:
...
@@ -797,6 +799,7 @@ class Scheduler:
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
swapped_in
.
infeasible_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
)
def
_schedule_chunked_prefill
(
self
):
def
_schedule_chunked_prefill
(
self
):
...
@@ -883,6 +886,7 @@ class Scheduler:
...
@@ -883,6 +886,7 @@ class Scheduler:
swapped_in
.
blocks_to_copy
),
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
def
_schedule
(
self
)
->
SchedulerOutputs
:
...
...
vllm/engine/async_llm_engine.py
View file @
bc8ad684
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
# Execute the model.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
seq_group_metadata_list
,
execute_model_req
)
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
else
:
else
:
output
=
[]
output
=
[]
...
...
vllm/engine/llm_engine.py
View file @
bc8ad684
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
...
@@ -583,12 +583,16 @@ class LLMEngine:
...
@@ -583,12 +583,16 @@ class LLMEngine:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
output
=
self
.
model_executor
.
e
xecute
_m
odel
(
execute_model_req
=
E
xecute
M
odel
Request
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
else
:
else
:
output
=
[]
output
=
[]
...
...
vllm/executor/cpu_executor.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
import
torch
import
torch
...
@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
...
@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
...
@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
...
@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
logger
.
info
(
"# CPU blocks: %d"
,
num_gpu_blocks
)
logger
.
info
(
"# CPU blocks: %d"
,
num_gpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
def
execute_model
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_out
:
Dict
[
int
,
int
],
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
return
output
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
@@ -104,19 +96,10 @@ class CPUExecutor(ExecutorBase):
...
@@ -104,19 +96,10 @@ class CPUExecutor(ExecutorBase):
class
CPUExecutorAsync
(
CPUExecutor
,
ExecutorAsyncBase
):
class
CPUExecutorAsync
(
CPUExecutor
,
ExecutorAsyncBase
):
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Dict
[
int
,
int
],
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
blocks_to_swap_out
:
Dict
[
int
,
int
],
)(
execute_model_req
=
execute_model_req
,
)
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
return
output
return
output
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
...
...
vllm/executor/executor_base.py
View file @
bc8ad684
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
SpeculativeConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
class
ExecutorBase
(
ABC
):
class
ExecutorBase
(
ABC
):
...
@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
...
@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
execute_model
(
self
,
def
execute_model
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
"""Executes at least one model step on the given sequences."""
"""Executes at least one model step on the given sequences."""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -107,13 +104,8 @@ class ExecutorAsyncBase(ExecutorBase):
...
@@ -107,13 +104,8 @@ class ExecutorAsyncBase(ExecutorBase):
@
abstractmethod
@
abstractmethod
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
"""Executes one model step on the given sequences."""
"""Executes one model step on the given sequences."""
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
bc8ad684
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
@@ -117,20 +117,9 @@ class GPUExecutor(ExecutorBase):
...
@@ -117,20 +117,9 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Dict
[
int
,
int
],
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
,
)
return
output
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
...
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
seq_group_metadata_list
=
seq_group_metadata_list
,
)(
execute_model_req
=
execute_model_req
,
)
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
return
output
return
output
vllm/executor/neuron_executor.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
List
,
Set
,
Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
...
@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
"""
"""
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
def
execute_model
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_out
:
Dict
[
int
,
int
],
assert
(
execute_model_req
.
blocks_to_swap_in
==
{}
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
and
execute_model_req
.
blocks_to_swap_out
==
{}
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
and
execute_model_req
.
blocks_to_copy
==
{}),
(
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
and
blocks_to_copy
==
{}),
(
"Cache operations are not supported for Neuron backend."
)
"Cache operations are not supported for Neuron backend."
)
assert
num_lookahead_slots
==
0
,
(
assert
execute_model_req
.
num_lookahead_slots
==
0
,
(
"lookahead not supported for Neuron backend."
)
"lookahead not supported for Neuron backend."
)
output
=
self
.
driver_worker
.
execute_model
(
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
)
execute_model_req
.
seq_group_metadata_list
)
return
output
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
...
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
output
=
await
make_async
(
seq_group_metadata_list
=
seq_group_metadata_list
,
)
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
)
return
output
return
output
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
...
...
vllm/executor/ray_gpu_executor.py
View file @
bc8ad684
...
@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
...
@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
get_vllm_instance_id
,
make_async
)
...
@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers
=
self
.
parallel_config
.
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
max_parallel_loading_workers
)
def
execute_model
(
self
,
def
execute_model
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
self
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
=
0
)
->
List
[
SamplerOutput
]:
all_outputs
=
self
.
_run_workers
(
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
"execute_model"
,
driver_kwargs
=
{
driver_kwargs
=
{
"execute_model_req"
:
execute_model_req
},
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
"num_lookahead_slots"
:
num_lookahead_slots
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
# Only the driver worker returns the sampling results.
...
...
vllm/sequence.py
View file @
bc8ad684
"""Sequence and its related classes."""
"""Sequence and its related classes."""
import
copy
import
copy
import
enum
import
enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Union
from
vllm.block
import
LogicalTokenBlock
from
vllm.block
import
LogicalTokenBlock
...
@@ -734,3 +734,33 @@ class SamplerOutput:
...
@@ -734,3 +734,33 @@ class SamplerOutput:
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@
dataclass
class
ExecuteModelRequest
:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dict
)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
field
(
default_factory
=
dict
)
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
=
0
# The number of requests in the running queue.
running_queue_size
:
int
=
0
def
clone
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
"ExecuteModelRequest"
:
"""Clone the request with a new sequence group metadata list."""
return
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
self
.
blocks_to_swap_in
.
copy
(),
blocks_to_swap_out
=
self
.
blocks_to_swap_out
.
copy
(),
blocks_to_copy
=
self
.
blocks_to_copy
.
copy
(),
num_lookahead_slots
=
self
.
num_lookahead_slots
,
running_queue_size
=
self
.
running_queue_size
,
)
vllm/spec_decode/batch_expansion.py
View file @
bc8ad684
from
itertools
import
chain
,
count
from
itertools
import
chain
,
count
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
from
typing
import
Iterator
,
List
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
...
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
def
score_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
)
->
SpeculativeScores
:
"""Score the proposed tokens via the scorer model.
"""Score the proposed tokens via the scorer model.
...
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.
no speculation is produced for that sequence.
Args:
Args:
seq_group_metadata_list: The input sequence group metadata.
execute_model_req: The execution request.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
proposals: The speculative proposals to score.
proposals: The speculative proposals to score.
Returns:
Returns:
SpeculativeScores: The scores of each speculative token, along with
SpeculativeScores: The scores of each speculative token, along with
...
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
proposal_lens_list
=
proposal_lens_list
,
)
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
execute_model_req
=
execute_model_req
.
clone
(
blocks_to_swap_in
=
blocks_to_swap_in
,
seq_group_metadata_list
=
target_seq_group_metadata_list
,
))
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
seq_group_metadata_list
),
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
spec_indices
=
spec_indices
,
k
=
k
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
)
return
SpeculativeScores
(
return
SpeculativeScores
(
...
...
vllm/spec_decode/interfaces.py
View file @
bc8ad684
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
@
dataclass
@
dataclass
...
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
...
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
@
abstractmethod
@
abstractmethod
def
get_proposals
(
def
get_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
...
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
@
abstractmethod
@
abstractmethod
def
score_proposals
(
def
score_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
)
->
SpeculativeScores
:
raise
NotImplementedError
raise
NotImplementedError
vllm/spec_decode/multi_step_worker.py
View file @
bc8ad684
import
copy
import
copy
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
List
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
...
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
sampler_output
(
def
sampler_output
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
sample_len
:
int
,
sample_len
:
int
,
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
"""Run the model forward pass sample_len times. Returns the list of
...
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
...
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
For multi step worker, this indicator shall be True.
For multi step worker, this indicator shall be True.
"""
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
self
.
_raise_if_unsupported
(
execute_model_req
)
blocks_to_swap_out
,
blocks_to_copy
)
# Shallow copy input data so modifications (such as appending tokens)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
# do not cause side-effects.
copied_seq_group_metadata_list
=
self
.
_shallow_copy_inputs
(
copied_seq_group_metadata_list
=
self
.
_shallow_copy_inputs
(
seq_group_metadata_list
)
execute_model_req
.
seq_group_metadata_list
)
copied_execute_model_req
=
execute_model_req
.
clone
(
copied_seq_group_metadata_list
)
# Assert enough KV space for sample_len tokens per sequence.
# Assert enough KV space for sample_len tokens per sequence.
self
.
_assert_enough_kv_space
(
seq_group_metadata_list
,
sample_len
)
self
.
_assert_enough_kv_space
(
execute_model_req
.
seq_group_metadata_list
,
sample_len
)
# Run model sample_len times.
# Run model sample_len times.
model_outputs
=
[]
model_outputs
=
[]
for
_
in
range
(
sample_len
):
for
_
in
range
(
sample_len
):
model_output
=
super
().
execute_model
(
model_output
=
super
().
execute_model
(
seq_group_metadata_list
=
copied_seq_group_metadata_list
,
execute_model_req
=
copied_execute_model_req
)
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
assert
(
len
(
model_output
)
==
1
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
model_output
=
model_output
[
0
]
...
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
...
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
def
get_spec_proposals
(
def
get_spec_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
def
_append_new_tokens
(
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
self
,
model_output
:
SamplerOutput
,
...
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
...
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
def
_raise_if_unsupported
(
def
_raise_if_unsupported
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
)
->
None
:
"""MultiStepWorker does not yet implement support for cache swap
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
operations or beam search.
"""
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
raise
NotImplementedError
(
"MultiStepWorker does not support cache operations"
)
"MultiStepWorker does not support cache operations"
)
if
any
(
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
"MultiStepWorker does not support beam search."
)
vllm/spec_decode/ngram_worker.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
...
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# NGram don't need gpu sampler
# NGram don't need gpu sampler
pass
pass
def
execute_model
(
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
None
:
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
"""NGram doesn't depend on model execution, just pass this function"""
pass
pass
...
@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def
sampler_output
(
def
sampler_output
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
sample_len
:
int
,
sample_len
:
int
,
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
"""NGram match algo to pick proposal candidate. Returns the list of
"""NGram match algo to pick proposal candidate. Returns the list of
...
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
For ngram worker, we already done needed transposed internal, so the
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
indicator pass to sampler_output_to_torch shall be False.
"""
"""
self
.
_raise_if_unsupported
(
self
.
_raise_if_unsupported
(
execute_model_req
)
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
)
arr
=
[]
arr
=
[]
has_spec_out
=
False
has_spec_out
=
False
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
:
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
...
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
indices
=
token_ids
.
unsqueeze
(
2
)
indices
=
token_ids
.
unsqueeze
(
2
)
token_probs
=
torch
.
zeros
(
token_probs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
token_logprobs
=
torch
.
zeros
(
token_logprobs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
(
len
(
execute_model_req
.
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
for
i
in
range
(
len
(
seq_group_metadata_list
)):
for
i
in
range
(
len
(
execute_model_req
.
seq_group_metadata_list
)):
outputs
.
append
(
outputs
.
append
(
SamplerOutput
(
SamplerOutput
(
outputs
=
None
,
outputs
=
None
,
...
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def
get_spec_proposals
(
def
get_spec_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
def
_raise_if_unsupported
(
def
_raise_if_unsupported
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
)
->
None
:
"""NGramWorker does not yet implement support for cache swap
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
operations or beam search.
"""
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
raise
NotImplementedError
(
"NGramWorker does not support cache operations"
)
"NGramWorker does not support cache operations"
)
if
any
(
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"NGramWorker does not support beam search."
)
"NGramWorker does not support beam search."
)
vllm/spec_decode/spec_decode_worker.py
View file @
bc8ad684
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
...
@@ -189,69 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -189,69 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""Perform speculative decoding on the input batch.
"""
"""
assert
seq_group_metadata_list
is
not
None
,
(
assert
execute_model_req
.
seq_group_metadata_list
is
not
None
,
(
"speculative decoding "
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
"requires non-None seq_group_metadata_list"
)
#logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
# num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# Used for prefill.
if
num_lookahead_slots
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
return
self
.
_run_no_spec
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
seq_group_metadata_list
=
seq_group_metadata_list
,
return
self
.
_run_no_spec
(
execute_model_req
)
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
blocks_to_copy
=
blocks_to_copy
,
)
return
self
.
_run_speculative_decoding_step
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_lookahead_slots
,
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
def
_run_no_spec
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to the
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
proposer and scorer model so that the KV cache is consistent between the
two.
two.
"""
"""
#logger.info("run proposer worker no spec")
#logger.info("run proposer worker no spec")
self
.
proposer_worker
.
execute_model
(
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
#logger.info("run target worker no spec")
#logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
assert
len
(
sampler_output
)
==
1
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
sampler_output
=
sampler_output
[
0
]
...
@@ -264,13 +233,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -264,13 +233,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
def
_run_speculative_decoding_step
(
def
_run_speculative_decoding_step
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
)
->
List
[
SamplerOutput
]:
"""Execute a single step of speculative decoding.
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
This invokes the proposer worker to get k speculative tokens for each
...
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
#logger.info("get spec proposals")
#logger.info("get spec proposals")
# Generate proposals using draft worker.
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
#logger.info("score proposals")
#logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
execute_model_req
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
,
proposals
,
proposals
,
)
)
#logger.info("verify proposals")
#logger.info("verify proposals")
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
#logger.info("create output list")
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
target_logprobs
=
target_logprobs
,
k
=
k
)
k
=
execute_model_req
.
num_lookahead_slots
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
def
_verify_tokens
(
...
...
vllm/spec_decode/top1_proposer.py
View file @
bc8ad684
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.util
import
sampler_output_to_torch
...
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):
def
get_proposals
(
def
get_proposals
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
execute_model_req
:
ExecuteModelRequest
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
proposal_len
:
int
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
Sequences which would exceed the max model length are skipped during
speculation.
speculation.
"""
"""
proposal_len
=
execute_model_req
.
num_lookahead_slots
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
# Split speculative- and non-speculative- sequences.
(
(
...
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
# token_ids is like [batch] format in proposal_len size list,
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# while if it is false, the format would be [proposal_len]
# in batch size list
# in batch size list
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_outpu
t
(
nonzero_execute_model_req
=
ExecuteModelReques
t
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
blocks_to_swap_in
=
blocks_to_swap_in
,
num_lookahead_slots
=
proposal_len
,
blocks_to_swap_out
=
blocks_to_swap_out
,
)
blocks_to_copy
=
blocks_to_copy
,
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_output
(
execute_model_req
=
nonzero_execute_model_req
,
sample_len
=
proposal_len
,
sample_len
=
proposal_len
,
)
)
else
:
else
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment