Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e95cd879
Unverified
Commit
e95cd879
authored
Apr 16, 2024
by
Cade Daniel
Committed by
GitHub
Apr 16, 2024
Browse files
[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
parent
69e1d2fb
Changes
31
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1207 additions
and
374 deletions
+1207
-374
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+70
-0
tests/core/utils.py
tests/core/utils.py
+10
-7
tests/engine/output_processor/test_multi_step.py
tests/engine/output_processor/test_multi_step.py
+270
-0
tests/spec_decode/e2e/test_correctness.py
tests/spec_decode/e2e/test_correctness.py
+114
-13
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+2
-2
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+20
-12
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+1
-1
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+0
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+2
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+3
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+53
-318
vllm/engine/output_processor/__init__.py
vllm/engine/output_processor/__init__.py
+0
-0
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+69
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+126
-0
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+276
-0
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+101
-0
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+16
-0
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+2
-1
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+3
-2
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+69
-10
No files found.
tests/core/block/e2e/test_correctness.py
View file @
e95cd879
...
...
@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
assert
baseline_token_ids
==
test_token_ids
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A small model keeps the test fast.
        "model": "facebook/opt-125m",

        # Eager mode skips cuda graph creation, again for speed.
        "enforce_eager": True,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 2,
        "max_num_seqs": 2,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
    {
        "use_v2_block_manager": False,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 0,
    },
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 5,
    },
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
                                          test_llm_generator, batch_size):
    """Check chunked prefill produces the same tokens under BlockManagerV2.

    The baseline generator runs with `use_v2_block_manager=False`; the
    generator under test runs with it enabled, once with zero lookahead slots
    and once with five. Sampling uses temperature 0.0 and ignores EOS, and the
    token ids from both generators are required to be equal.
    """
    num_output_tokens = 32
    greedy_temperature = 0.0

    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Repeat the seed prompts round-robin until the batch is full.
    prompts = [
        text for text, _ in zip(cycle(seed_prompts), range(batch_size))
    ]

    sampling_params = SamplingParams(
        max_tokens=num_output_tokens,
        ignore_eos=True,
        temperature=greedy_temperature,
    )

    print('Getting token ids with BlockManagerV1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with BlockManagerV2')
    test_token_ids = get_token_ids_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Compare per prompt first so a mismatch points at a single sequence.
    for expected, actual in zip(baseline_token_ids, test_token_ids):
        assert expected == actual

    # Then compare the full results, which also catches a length mismatch.
    assert baseline_token_ids == test_token_ids
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
...
...
tests/core/utils.py
View file @
e95cd879
import
time
from
typing
import
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
from
vllm
import
SamplingParams
from
vllm.lora.request
import
LoRARequest
...
...
@@ -31,14 +31,17 @@ def create_dummy_prompt(
def
create_seq_group
(
seq_prompt_len
=
1024
,
seq_output_lens
=
(
128
,
),
request_id
=
'0'
,
seq_id_start
=
0
,
)
->
SequenceGroup
:
seq_prompt_len
:
int
=
1024
,
seq_output_lens
:
Iterable
[
int
]
=
(
128
,
),
request_id
:
str
=
'0'
,
seq_id_start
:
int
=
0
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
)
->
SequenceGroup
:
assert
len
(
seq_output_lens
)
>
0
if
sampling_params
is
None
:
sampling_params
=
SamplingParams
()
prompt_token_ids
=
[
0
]
*
seq_prompt_len
seqs
=
[]
...
...
@@ -60,7 +63,7 @@ def create_seq_group(
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
seqs
,
sampling_params
=
S
ampling
P
arams
()
,
sampling_params
=
s
ampling
_p
arams
,
arrival_time
=
time
.
time
(),
)
...
...
tests/engine/output_processor/test_multi_step.py
0 → 100644
View file @
e95cd879
import
random
from
unittest.mock
import
MagicMock
import
pytest
from
transformers
import
PreTrainedTokenizer
from
tests.core.utils
import
create_seq_group
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.multi_step
import
MultiStepOutputProcessor
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
    """Check that multi-step processing appends every new token id.

    The sequence group is created with `ignore_eos=True` and with
    `max_tokens` large enough to hold all new tokens, so the tail of the
    sequence is expected to equal the new token ids after processing.
    """
    processor = MultiStepOutputProcessor(
        detokenizer=MagicMock(spec=Detokenizer),
        scheduler=MagicMock(spec=Scheduler),
        seq_counter=Counter(),
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=MagicMock(spec=StopChecker),
    )

    params = SamplingParams(max_tokens=seq_output_len + num_new_tokens,
                            ignore_eos=True)
    group = create_seq_group(
        seq_prompt_len=1024,
        seq_output_lens=[seq_output_len],
        sampling_params=params,
    )

    sequence = group.get_seqs()[0]
    sequence.status = SequenceStatus.RUNNING

    appended_ids = list(range(num_new_tokens))

    # One SequenceGroupOutput per new token, each holding a single sample.
    step_outputs = []
    for token_id in appended_ids:
        sample = SequenceOutput(
            parent_seq_id=sequence.seq_id,
            output_token=token_id,
            logprobs={token_id: Logprob(0.0)},
        )
        step_outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))

    tail_len = len(appended_ids)

    # Sanity check: the tail must not already match before processing.
    assert sequence.get_token_ids()[-tail_len:] != appended_ids
    processor.process_outputs(group, step_outputs)
    assert sequence.get_token_ids()[-tail_len:] == appended_ids
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
                             seq_output_len: int, max_tokens: int):
    """Check that tokens beyond `max_tokens` are dropped, not appended.

    More new tokens are offered than fit under `max_tokens`; after
    processing, the sequence length is expected to be exactly
    `seq_prompt_len + max_tokens` and only the leading new tokens kept.
    """
    processor = MultiStepOutputProcessor(
        detokenizer=MagicMock(spec=Detokenizer),
        scheduler=MagicMock(spec=Scheduler),
        seq_counter=Counter(),
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=MagicMock(spec=StopChecker),
    )

    group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=max_tokens, ),
    )

    sequence = group.get_seqs()[0]
    sequence.status = SequenceStatus.RUNNING

    offered_ids = list(range(num_new_tokens))

    # One SequenceGroupOutput per offered token, each with a single sample.
    step_outputs = []
    for token_id in offered_ids:
        sample = SequenceOutput(
            parent_seq_id=sequence.seq_id,
            output_token=token_id,
            logprobs={token_id: Logprob(0.0)},
        )
        step_outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))

    length_before = seq_prompt_len + seq_output_len
    assert sequence.get_len() == length_before
    processor.process_outputs(group, step_outputs)

    # The processed sequence must stop exactly at the max_tokens limit.
    assert sequence.get_len() == seq_prompt_len + max_tokens

    # Only the tokens that fit in the remaining budget should be appended.
    remaining_budget = max_tokens - seq_output_len
    kept_ids = offered_ids[:remaining_budget]
    assert sequence.get_token_ids()[-len(kept_ids):] == kept_ids
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                               seq_output_len: int, seed: int):
    """Check that the eos token is kept and everything after it is dropped.

    An eos token id is placed at a seeded-random position in the new tokens.
    After processing, the sequence is expected to end with the new tokens up
    to and including the eos token.
    """
    random.seed(seed)

    eos_token_id = 100

    processor = MultiStepOutputProcessor(
        detokenizer=MagicMock(spec=Detokenizer),
        scheduler=MagicMock(spec=Scheduler),
        seq_counter=Counter(),
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=MagicMock(spec=StopChecker),
    )

    group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Leave room for every new token so only eos can truncate.
            max_tokens=seq_output_len + num_new_tokens, ),
    )

    sequence = group.get_seqs()[0]
    sequence.status = SequenceStatus.RUNNING

    offered_ids = list(range(num_new_tokens))
    assert eos_token_id not in offered_ids

    # Overwrite one randomly chosen position with the eos token.
    eos_index = random.randint(0, len(offered_ids) - 1)
    offered_ids[eos_index] = eos_token_id

    # One SequenceGroupOutput per offered token, each with a single sample.
    step_outputs = []
    for token_id in offered_ids:
        sample = SequenceOutput(
            parent_seq_id=sequence.seq_id,
            output_token=token_id,
            logprobs={token_id: Logprob(0.0)},
        )
        step_outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))

    length_before = seq_prompt_len + seq_output_len
    assert sequence.get_len() == length_before
    processor.process_outputs(group, step_outputs)

    # The processed sequence must end at the eos token, inclusive.
    num_kept = eos_index + 1
    assert sequence.get_len() == length_before + num_kept

    # The appended tokens are the offered ones up to and including eos.
    kept_ids = offered_ids[:num_kept]
    assert sequence.get_token_ids()[-len(kept_ids):] == kept_ids
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                              seq_output_len: int, seed: int):
    """When sampling parameters dictate that we should ignore the eos token id,
    ensure all token ids are appended even if the eos token id is emitted.

    An eos token id is placed at a seeded-random position in the new tokens
    and the sequence group is created with `ignore_eos=True`. After
    processing, the sequence is expected to have grown by `num_new_tokens`
    and to end with the complete list of new tokens.
    """
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=scheduler,
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space.
            max_tokens=seq_output_len + num_new_tokens,
            ignore_eos=True,
        ),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids

    # Overwrite one randomly chosen position with the eos token.
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id

    outputs = [
        SequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to go beyond eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

    # Expect every new token to be appended. The previous slice bound
    # (seq_output_len + num_new_tokens - seq_output_len) reduced to
    # num_new_tokens, i.e. the whole list, so state that directly.
    expected_appended_tokens = list(new_token_ids)
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens
def mock_tokenizer(eos_token_id=1000):
    """Return a MagicMock specced on PreTrainedTokenizer.

    The mock's `eos_token_id` attribute is set to the given value
    (1000 by default).
    """
    # Keyword arguments to the mock constructor are set as attributes on the
    # created mock.
    return MagicMock(spec=PreTrainedTokenizer, eos_token_id=eos_token_id)
tests/spec_decode/e2e/test_correctness.py
View file @
e95cd879
from
itertools
import
cycle
from
typing
import
List
,
Tuple
import
pytest
from
transformers
import
AutoTokenizer
from
vllm
import
SamplingParams
...
...
@@ -7,18 +11,47 @@ from vllm import SamplingParams
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
"model"
:
"facebook/opt-125m"
,
"speculative_model"
:
"facebook/opt-125m"
,
"num_speculative_tokens"
:
5
,
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip real loading for fast test.
"load_format"
:
"dummy"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
1
,
},
{
# No spec decode.
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_config
(
test_llm_generator
):
output_len
=
1024
def
test_spec_decode_e2e_logical_flow
(
test_llm_generator
,
batch_size
:
int
):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len
=
32
temperature
=
0.0
prompts
=
[
...
...
@@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator):
"The future of AI is"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
skip_special_tokens
=
True
,
spaces_between_special_tokens
=
False
,
)
batch_tokens
,
batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
# Expect a generation for each prompt in the batch.
assert
len
(
batch_token_ids
)
==
len
(
prompts
)
# Expect each generation to have expected number of tokens (note
# ignore_eos=True).
assert
all
(
len
(
token_ids
)
==
output_len
for
token_ids
in
batch_token_ids
)
# Expect detokenized string to match.
tok
=
AutoTokenizer
.
from_pretrained
(
"JackFram/llama-68m"
)
for
actual_tokens
,
actual_token_ids
in
zip
(
batch_tokens
,
batch_token_ids
):
expected_tokens
=
tok
.
decode
(
actual_token_ids
)
print
(
f
"
{
actual_token_ids
=
}
"
)
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Skip real loading for fast test.
"load_format"
:
"dummy"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail
(
test_llm_generator
):
"""Verify that speculative decoding with Ray fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
AssertionError
,
match
=
"Speculative decoding not yet supported for GPU backend"
):
get_token_ids_from_llm_generator
(
test_llm_generator
,
prompts
,
with
pytest
.
raises
(
AssertionError
,
match
=
"Speculative decoding not yet supported for "
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
tokens
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
del
llm
return
token_ids
return
tokens
,
token_ids
tests/spec_decode/test_multi_step_worker.py
View file @
e95cd879
...
...
@@ -125,7 +125,7 @@ def test_same_output_for_single_step():
zero_kv_cache
(
worker
.
cache_engine
)
set_random_seed
(
seed
)
expected_output
=
worker
.
execute_model
(
**
single_step_execute_model_data
.
to_dict
(),
)
**
single_step_execute_model_data
.
to_dict
(),
)
[
0
]
actual_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
actual_output
...
...
@@ -219,7 +219,7 @@ def test_same_output_for_multi_step():
continuations
=
continuations
,
final_seq_lens
=
final_seq_lens
))
single_step_output
.
app
end
(
single_step_output
.
ext
end
(
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
))
# Append output tokens to new sequence data.
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
e95cd879
...
...
@@ -6,6 +6,7 @@ import torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
...
...
@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_spec_tokens
=
k
)
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
assert
len
(
call_args_list
)
==
1
...
...
@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_spec_tokens
=
k
)
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
seen_contexts
=
[]
...
...
@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_worker
.
execute_model
.
return_value
=
target_output
[
0
]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]
]
exception_secret
=
'artifical stop'
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_spec_tokens
=
k
)
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_lookahead_slots
=
k
)
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
args
,
_
=
rejection_sampler
.
call_args_list
[
0
]
...
...
@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_worker
.
execute_model
.
return_value
=
target_output
[
0
]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]
]
rejection_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
...
...
@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler
.
return_value
=
rejection_sampler_output
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_
spec_token
s
=
k
)
num_
lookahead_slot
s
=
k
)
expected_output
=
create_sampler_output_list
(
rejection_sampler_output
.
transpose
(
0
,
1
),
[
None
for
_
in
range
(
k
+
1
)])
...
...
@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_worker
.
execute_model
.
return_value
=
target_output
[
0
]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]
]
rejection_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
...
...
@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics
)
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_
spec_token
s
=
k
)
num_
lookahead_slot
s
=
k
)
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
call_args_list
=
(
...
...
@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
execute_model
.
return_value
=
[
MagicMock
(
spec
=
SamplerOutput
)]
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
...
...
@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
batch_size
,
k
,
prev_output_token_len
=
0
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_
spec_token
s
=
k
)
num_
lookahead_slot
s
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
...
...
@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
()
,
return_python_output
=
False
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
...
...
@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
execute_model
.
return_value
=
[
MagicMock
(
spec
=
SamplerOutput
)]
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
...
...
@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
batch_size
,
k
,
prev_output_token_len
=
0
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
num_
spec_token
s
=
k
)
num_
lookahead_slot
s
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
...
...
@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
()
,
return_python_output
=
False
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
...
...
tests/spec_decode/utils.py
View file @
e95cd879
...
...
@@ -212,7 +212,7 @@ def create_sampler_output_list(
SequenceOutput
(
output_token
=
token_id
,
parent_seq_id
=
seq_ids
[
seq_index
],
logprobs
=
{
token_id
:
0
},
logprobs
=
{
token_id
:
Logprob
(
0
)
},
)
],
prompt_logprobs
=
None
,
...
...
vllm/core/block/block_table.py
View file @
e95cd879
...
...
@@ -104,7 +104,6 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert
self
.
_is_allocated
assert
token_ids
,
"can't append empty token ids"
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
...
...
vllm/core/scheduler.py
View file @
e95cd879
...
...
@@ -762,9 +762,7 @@ class Scheduler:
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
(
prefills
.
num_lookahead_slots
+
running_scheduled
.
num_lookahead_slots
+
swapped_in
.
num_lookahead_slots
),
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
)
def
_schedule_chunked_prefill
(
self
):
...
...
@@ -850,9 +848,7 @@ class Scheduler:
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
(
prefills
.
num_lookahead_slots
+
running_scheduled
.
num_lookahead_slots
+
swapped_in
.
num_lookahead_slots
),
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
...
...
vllm/engine/async_llm_engine.py
View file @
e95cd879
...
...
@@ -217,7 +217,9 @@ class _AsyncLLMEngine(LLMEngine):
else
:
output
=
[]
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
)
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
)
async
def
encode_request_async
(
self
,
...
...
vllm/engine/llm_engine.py
View file @
e95cd879
This diff is collapsed.
Click to expand it.
vllm/engine/output_processor/__init__.py
0 → 100644
View file @
e95cd879
vllm/engine/output_processor/interfaces.py
0 → 100644
View file @
e95cd879
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
Iterable
,
List
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceGroupOutput
from
vllm.transformers_utils.detokenizer
import
Detokenizer
class SequenceGroupOutputProcessor(ABC):
    """Abstract processor for the new token ids of a sequence group.

    Implementations take care of detokenization, stop checking, and
    freeing/forking sequences with the scheduler.

    The interface is tightly coupled to the LLMEngine and is best read as an
    extension of it. It exists to keep the LLMEngine class simpler and to let
    single-step decoding (which supports beam search sequence forking) and
    multi-step decoding (no beam search, but supports speculative decoding)
    have separate implementations.
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
    ):
        """Build the output processor matching the scheduler configuration.

        A multi-step output processor is returned when
        `scheduler_config.num_lookahead_slots` is non-zero; otherwise a
        single-step output processor is returned.
        """
        if scheduler_config.num_lookahead_slots != 0:
            # Imported lazily to avoid an import cycle.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

        # Imported lazily to avoid an import cycle.
        from vllm.engine.output_processor.single_step import (
            SingleStepOutputProcessor)
        return SingleStepOutputProcessor(
            scheduler_config,
            detokenizer,
            scheduler,
            seq_counter,
            stop_checker,
        )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Consume the new token ids produced for `sequence_group`.

        Implementations handle logic such as detokenization, stop checking,
        and freeing/forking sequences in the scheduler.
        """
        pass
vllm/engine/output_processor/multi_step.py
0 → 100644
View file @
e95cd879
from
typing
import
Callable
,
Iterable
,
List
from
transformers
import
PreTrainedTokenizer
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
logger
=
init_logger
(
__name__
)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """Output processor for "multi-step decoding", where a single worker
    invocation may yield several tokens per sequence.

    It covers detokenization and stop-condition handling for that mode, which
    is currently mutually exclusive with advanced sampling techniques such as
    beam search; that is why it lives apart from the single-step processor.

    Responsibilities: append every new token id to its sequence, detokenize
    the new ids, drop output tokens that follow an eos token, and cope with
    sequences in the same batch receiving different numbers of new tokens.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs."""
        self.stop_checker = stop_checker
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.seq_counter = seq_counter
        self.scheduler = scheduler
        self.detokenizer = detokenizer

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append the new tokens in ``outputs`` to the group's sequence.

        Only sequence groups holding exactly one running sequence are
        supported; that sequence may receive more than one new token.

        Stop-condition checking, detokenization and freeing of a finished
        sequence are applied, including the case where tokens were emitted
        after the EOS token.
        """
        running = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert running, "expected running sequences"
        assert len(running) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = running[0]

        # One sequence per group, so each step contributes its first sample.
        first_samples = [step_output.samples[0] for step_output in outputs]

        # A token id of -1 marks an invalid output token (eg. rejected by
        # spec decode); those are skipped.
        accepted = [
            candidate for candidate in first_samples
            if candidate.output_token != -1
        ]
        assert accepted

        self._process_seq_outputs(seq, accepted,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the tokens of ``valid_samples`` to ``seq`` one at a time.

        The token list is first cut down to the ``max_tokens`` budget and,
        unless ``ignore_eos`` is set, to the first EOS token (inclusive). Each
        remaining token is appended, optionally detokenized, and passed to the
        stop checker; processing ends as soon as the sequence is finished, and
        a finished sequence is freed in the scheduler.
        """
        new_token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary. A negative budget is exactly
        # the number of trailing tokens to drop.
        budget_left = sampling_params.max_tokens - (seq.get_output_len() +
                                                    len(new_token_ids))
        if budget_left < 0:
            new_token_ids = new_token_ids[:budget_left]

        # Spec decode emits a fixed number of tokens without evaluating stop
        # conditions inside the block, so anything after an EOS token must be
        # discarded here or the EOS would be unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # A plain scan is used rather than .index, since raising an
            # exception in the happy path is expensive.
            for position, token_id in enumerate(new_token_ids):
                if token_id == eos_token_id:
                    new_token_ids = new_token_ids[:position + 1]
                    break

        # Feed the tokens in one by one, as if each had arrived in its own
        # step.
        for token_id in new_token_ids:
            seq.append_token_id(
                token_id=token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={token_id: Logprob(0.0)},
            )

            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)
            else:
                new_char_count = 0

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
vllm/engine/output_processor/single_step.py
0 → 100644
View file @
e95cd879
from
typing
import
Iterable
,
List
,
Tuple
,
Union
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
logger
=
init_logger
(
__name__
)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            scheduler_config: Read for ``max_model_len`` in the beam search
                early-stopping check.
            detokenizer: Decodes prompt logprobs and newly appended tokens.
            scheduler: Used to fork and free sequences.
            seq_counter: Source of fresh sequence ids; consumed with
                ``next()`` when a parent sequence is forked.
            stop_checker: Marks sequences as finished when a stop condition
                is met.
        """
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        ``outputs`` must contain exactly one element (one decoding step).
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0])

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to ``seq_group``.

        Steps, in order:
          1. Decode and attach prompt logprobs (when present and detokenize
             is enabled).
          2. Group samples by parent sequence; abort and free parents that
             received no sample; fork parents that received several.
          3. Detokenize every child and run the stop checker on it.
          4. Without beam search: add and fork the new children, then free
             finished parents.
          5. With beam search: keep the best ``best_of`` finished and running
             candidates, decide whether to stop, then add/fork the selected
             children and remove/free the unselected ones.
        """

        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        # NOTE(review): the logprobs are attached to the group only when
        # detokenize is enabled; confirm that skipping the assignment for
        # detokenize=False is intended.
        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        # Map each running parent's seq_id to the samples generated from it.
        # A sample whose parent_seq_id is not a running parent raises KeyError.
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            # A (parent, parent) pair marks a continued parent, as opposed to
            # a forked child; later loops tell them apart with `seq is parent`.
            child_seqs.append((parent, parent))

        # Detokenize each child (if enabled) and let the stop checker decide
        # whether it is finished.
        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
                                                  seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        # With beam search, best_of is used as the beam width.
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need to remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            # Both lists are sorted best-first, so index 0 is the best running
            # candidate and index beam_width - 1 is the worst finished
            # sequence that would still be kept.
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search can stop.

        Args:
            early_stopping: ``True``, ``False`` or the string ``"never"``.
            sampling_params: Must have ``use_beam_search`` set; supplies
                ``length_penalty`` and ``max_tokens``.
            best_running_seq: Highest-scoring sequence still running.
            current_worst_seq: Lowest-scoring finished sequence that would
                still be kept.

        Returns:
            ``True`` when ``early_stopping`` is ``True``; otherwise whether
            the beam search score of ``current_worst_seq`` is at least the
            highest score considered attainable by ``best_running_seq``.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            # Use the best running candidate's current score as the bound.
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                # NOTE(review): this takes the larger of prompt_len +
                # max_tokens and max_model_len; confirm max (rather than min)
                # is the intended bound.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score
vllm/engine/output_processor/stop_checker.py
0 → 100644
View file @
e95cd879
from
typing
import
Callable
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
class StopChecker:
    """LLMEngine helper that holds the stop-checking logic.

    It decides whether a sequence is finished because the eos token was
    emitted, a stop token or stop string was produced, the max model length
    was exceeded, or max_tokens was consumed.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence],
                                                 PreTrainedTokenizer]):
        """Remember the model length cap and the tokenizer lookup callable."""
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.max_model_len = max_model_len

    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> None:
        """Mark ``seq`` as finished if it meets a stop condition.

        ``new_char_count`` is the number of characters that the newly
        generated token added to the sequence's output text.

        Conditions are tested in this order, and the first match wins:
        EOS token, stop token id, stop string, max model length, max_tokens.
        Nothing is tested while fewer than ``min_tokens`` output tokens
        exist.
        """
        # Until min_tokens outputs have been generated, no check is applied.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # EOS token emitted (unless the request asked to ignore it).
        if (not sampling_params.ignore_eos
                and seq.get_last_token_id() == seq.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Stop token emitted.
        # This assumes a single token produced per step.
        newest_token_id = seq.get_last_token_id()
        if newest_token_id in sampling_params.stop_token_ids:
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Strip the text contributed by the stop token.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = newest_token_id
            return

        # Stop string matched in the freshly added text.
        matched_stop = self._check_stop_strings(seq, new_char_count,
                                                sampling_params)
        if matched_stop is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = matched_stop
            return

        # Length caps: total length beyond max_model_len, or output length
        # equal to max_tokens.
        if (seq.get_len() > self.max_model_len
                or seq.get_output_len() == sampling_params.max_tokens):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Look for a stop string in the recent output text and truncate the
        text accordingly.

        Returns the matched stop string, or None when nothing matched or no
        new characters were added.
        """
        if not new_char_count:
            return None

        for candidate in sampling_params.stop:
            candidate_len = len(candidate)
            # Only the tail that could contain a new match is searched.
            search_start = -new_char_count - candidate_len
            match_at = seq.output_text.find(candidate, search_start)
            if match_at == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Keep the stop string: cut right after it.
                cut_at = match_at + candidate_len
                if cut_at >= len(seq.output_text):
                    # The stop string already ends the text.
                    return candidate
            else:
                # Drop the stop string: cut right before it.
                cut_at = match_at

            seq.output_text = seq.output_text[:cut_at]
            return candidate
        return None
vllm/engine/output_processor/util.py
0 → 100644
View file @
e95cd879
from
typing
import
List
from
vllm.sequence
import
SamplerOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
                                    num_seq_groups: int):
    """Transpose sampler outputs from [step][sequence group] ordering into
    [sequence group][step] ordering.

    Returns a list of ``num_seq_groups`` lists; the i-th list collects, in
    step order, the i-th entry of every step.
    """
    per_group = []
    for _ in range(num_seq_groups):
        per_group.append([])

    for step_outputs in sampler_outputs:
        group_index = 0
        for group_output in step_outputs:
            per_group[group_index].append(group_output)
            group_index += 1

    return per_group
vllm/executor/cpu_executor.py
View file @
e95cd879
...
...
@@ -74,7 +74,8 @@ class CPUExecutor(ExecutorBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
vllm/executor/executor_base.py
View file @
e95cd879
...
...
@@ -72,8 +72,9 @@ class ExecutorBase(ABC):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
"""Executes at least one model step on the given sequences."""
raise
NotImplementedError
@
abstractmethod
...
...
vllm/executor/gpu_executor.py
View file @
e95cd879
...
...
@@ -13,13 +13,17 @@ logger = init_logger(__name__)
class
GPUExecutor
(
ExecutorBase
):
def
_init_executor
(
self
)
->
None
:
assert
(
not
self
.
speculative_config
),
"Speculative decoding not yet supported for GPU backend"
"""Initialize the worker and load the model.
# Instantiate the worker and load the model to GPU.
self
.
_init_worker
()
If speculative decoding is enabled, we instead create the speculative
worker.
"""
if
self
.
speculative_config
is
None
:
self
.
_init_non_spec_worker
()
else
:
self
.
_init_spec_worker
()
def
_init_worker
(
self
):
def
_init_
non_spec_
worker
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from
vllm.worker.worker
import
Worker
...
...
@@ -46,6 +50,57 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def _init_spec_worker(self):
    """Set up a SpecDecodeWorker as the driver worker, with a draft model
    producing the proposals.
    """
    assert self.speculative_config is not None

    from vllm.spec_decode.multi_step_worker import MultiStepWorker
    from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
    from vllm.worker.worker import Worker

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    # Constructor arguments that are the same for the target and the draft
    # worker; only the model and parallel configs differ between them.
    shared_worker_kwargs = dict(
        scheduler_config=self.scheduler_config,
        device_config=self.device_config,
        cache_config=self.cache_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        lora_config=self.lora_config,
        vision_language_config=self.vision_language_config,
        is_driver_worker=True,
    )

    scorer = Worker(
        model_config=self.model_config,
        parallel_config=self.parallel_config,
        **shared_worker_kwargs,
    )

    proposer = MultiStepWorker(
        model_config=self.speculative_config.draft_model_config,
        parallel_config=self.speculative_config.draft_parallel_config,
        **shared_worker_kwargs,
    )

    combined_worker = SpecDecodeWorker.from_workers(proposer_worker=proposer,
                                                    scorer_worker=scorer)

    assert self.parallel_config.world_size == 1, (
        "GPUExecutor only supports single GPU.")

    self.driver_worker = combined_worker

    # Load model handled in spec decode worker.
    self.driver_worker.init_device()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
...
...
@@ -63,16 +118,20 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
,
)
return
output
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment