"deploy/chrek/cmd/vscode:/vscode.git/clone" did not exist on "bb8fc8a4a969357000caf57c79af47df6b2e2113"
Unverified Commit e95cd879 authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)

parent 69e1d2fb
...@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ...@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[
{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 2,
"max_num_seqs": 2,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
{
"use_v2_block_manager": False,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"use_v2_block_manager": True,
"num_lookahead_slots": 0,
},
{
"use_v2_block_manager": True,
"num_lookahead_slots": 5,
},
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify that chunked prefill works with BlockManagerV2, with and without
lookahead scheduling.
"""
output_len = 32
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
print('Getting token ids with BlockManagerV1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids with BlockManagerV2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator: for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True) outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
......
import time import time
from typing import Optional, Tuple from typing import Iterable, Optional, Tuple
from vllm import SamplingParams from vllm import SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -31,14 +31,17 @@ def create_dummy_prompt( ...@@ -31,14 +31,17 @@ def create_dummy_prompt(
def create_seq_group( def create_seq_group(
seq_prompt_len=1024, seq_prompt_len: int = 1024,
seq_output_lens=(128, ), seq_output_lens: Iterable[int] = (128, ),
request_id='0', request_id: str = '0',
seq_id_start=0, seq_id_start: int = 0,
) -> SequenceGroup: sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
assert len(seq_output_lens) > 0 assert len(seq_output_lens) > 0
if sampling_params is None:
sampling_params = SamplingParams()
prompt_token_ids = [0] * seq_prompt_len prompt_token_ids = [0] * seq_prompt_len
seqs = [] seqs = []
...@@ -60,7 +63,7 @@ def create_seq_group( ...@@ -60,7 +63,7 @@ def create_seq_group(
seq_group = SequenceGroup( seq_group = SequenceGroup(
request_id=request_id, request_id=request_id,
seqs=seqs, seqs=seqs,
sampling_params=SamplingParams(), sampling_params=sampling_params,
arrival_time=time.time(), arrival_time=time.time(),
) )
......
import random
from unittest.mock import MagicMock
import pytest
from transformers import PreTrainedTokenizer
from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
"""Verify multi-step decoding appends token ids correctly.
We append token ids and verify all the token ids were appended correctly.
Note that ignore_eos=True.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=1024,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=seq_output_len +
num_new_tokens,
ignore_eos=True),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
output_processor.process_outputs(seq_group, outputs)
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, max_tokens: int):
"""Verify tokens after max_tokens are dropped and not appended to the
sequence.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=max_tokens, ),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to not go over max tokens in len.
assert seq.get_len() == seq_prompt_len + max_tokens
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""Verify the eos token id is included in the sequence, but subsequent
tokens are dropped (not appended to sequence).
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
eos_token_id = 100
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens, ),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to not go beyond provided eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:eos_index + 1]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""When sampling parameters dictate that we should ignore the eos token id,
ensure all token ids are appended even if the eos token id is emitted.
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
eos_token_id = 100
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens,
ignore_eos=True,
),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to go beyond eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
def mock_tokenizer(eos_token_id=1000):
tokenizer = MagicMock(spec=PreTrainedTokenizer)
tokenizer.eos_token_id = eos_token_id
return tokenizer
from itertools import cycle
from typing import List, Tuple
import pytest import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams from vllm import SamplingParams
...@@ -7,18 +11,47 @@ from vllm import SamplingParams ...@@ -7,18 +11,47 @@ from vllm import SamplingParams
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
# Use a small model for a fast test. # Use a small model for a fast test.
"model": "facebook/opt-125m", # Note this is repeated in the test body; to initialize a tokenizer.
"speculative_model": "facebook/opt-125m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Skip real loading for fast test.
"load_format": "dummy",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode. # Required for spec decode.
"use_v2_block_manager": True "use_v2_block_manager": True
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 1,
},
{
# No spec decode.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator): def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
output_len = 1024 """Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len = 32
temperature = 0.0 temperature = 0.0
prompts = [ prompts = [
...@@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator): ...@@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator):
"The future of AI is", "The future of AI is",
] ]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
batch_tokens, batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
# Expect each generation to have expected number of tokens (note
# ignore_eos=True).
assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
expected_tokens = tok.decode(actual_token_ids)
print(f"{actual_token_ids=}")
assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Skip real loading for fast test.
"load_format": "dummy",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams( sampling_params = SamplingParams(
max_tokens=output_len, max_tokens=output_len,
ignore_eos=True, ignore_eos=True,
temperature=temperature, temperature=temperature,
) )
with pytest.raises( with pytest.raises(AssertionError,
AssertionError, match="Speculative decoding not yet supported for "):
match="Speculative decoding not yet supported for GPU backend"): get_output_from_llm_generator(test_llm_generator, prompts,
get_token_ids_from_llm_generator(test_llm_generator, prompts,
sampling_params) sampling_params)
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
for llm in llm_generator: for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True) outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs] token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm del llm
return token_ids return tokens, token_ids
...@@ -125,7 +125,7 @@ def test_same_output_for_single_step(): ...@@ -125,7 +125,7 @@ def test_same_output_for_single_step():
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
expected_output = worker.execute_model( expected_output = worker.execute_model(
**single_step_execute_model_data.to_dict(), ) **single_step_execute_model_data.to_dict(), )[0]
actual_token_ids = [ actual_token_ids = [
output.samples[0].output_token for output in actual_output output.samples[0].output_token for output in actual_output
...@@ -219,7 +219,7 @@ def test_same_output_for_multi_step(): ...@@ -219,7 +219,7 @@ def test_same_output_for_multi_step():
continuations=continuations, continuations=continuations,
final_seq_lens=final_seq_lens)) final_seq_lens=final_seq_lens))
single_step_output.append( single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), )) worker.execute_model(**execute_model_data.to_dict(), ))
# Append output tokens to new sequence data. # Append output tokens to new sequence data.
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): ...@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
execute_model_data, _, _ = create_batch(batch_size, k) execute_model_data, _, _ = create_batch(batch_size, k)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
call_args_list = draft_worker.get_spec_proposals.call_args_list call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker.execute_model.side_effect = ValueError(exception_secret) target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
seen_contexts = [] seen_contexts = []
...@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artifical stop' exception_secret = 'artifical stop'
rejection_sampler.side_effect = ValueError(exception_secret) rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1 assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0] args, _ = rejection_sampler.call_args_list[0]
...@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, rejection_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
...@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler.return_value = rejection_sampler_output rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
expected_output = create_sampler_output_list( expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
...@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs)
target_worker.execute_model.return_value = target_output[0] target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0, rejection_sampler_output = torch.randint(low=0,
high=vocab_size, high=vocab_size,
...@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics) mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = ( call_args_list = (
...@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
...@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0) batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
...@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False) **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with( target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict()) **execute_model_data.to_dict())
...@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64 rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
draft_worker.device = 'cuda' draft_worker.device = 'cuda'
target_worker.device = 'cuda' target_worker.device = 'cuda'
...@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0) batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k) num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
...@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict(), return_python_output=False) **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with( target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict()) **execute_model_data.to_dict())
......
...@@ -212,7 +212,7 @@ def create_sampler_output_list( ...@@ -212,7 +212,7 @@ def create_sampler_output_list(
SequenceOutput( SequenceOutput(
output_token=token_id, output_token=token_id,
parent_seq_id=seq_ids[seq_index], parent_seq_id=seq_ids[seq_index],
logprobs={token_id: 0}, logprobs={token_id: Logprob(0)},
) )
], ],
prompt_logprobs=None, prompt_logprobs=None,
......
...@@ -104,7 +104,6 @@ class BlockTable: ...@@ -104,7 +104,6 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended. token_ids (List[int]): The sequence of token IDs to be appended.
""" """
assert self._is_allocated assert self._is_allocated
assert token_ids, "can't append empty token ids"
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots) num_lookahead_slots)
......
...@@ -762,9 +762,7 @@ class Scheduler: ...@@ -762,9 +762,7 @@ class Scheduler:
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups, ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=(prefills.num_lookahead_slots + num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_scheduled.num_lookahead_slots +
swapped_in.num_lookahead_slots),
) )
def _schedule_chunked_prefill(self): def _schedule_chunked_prefill(self):
...@@ -850,9 +848,7 @@ class Scheduler: ...@@ -850,9 +848,7 @@ class Scheduler:
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups, ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=(prefills.num_lookahead_slots + num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_scheduled.num_lookahead_slots +
swapped_in.num_lookahead_slots),
) )
def _schedule(self) -> SchedulerOutputs: def _schedule(self) -> SchedulerOutputs:
......
...@@ -217,7 +217,9 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -217,7 +217,9 @@ class _AsyncLLMEngine(LLMEngine):
else: else:
output = [] output = []
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
async def encode_request_async( async def encode_request_async(
self, self,
......
This diff is collapsed.
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
class SequenceGroupOutputProcessor(ABC):
"""Interface for logic that processes new token ids in sequence groups,
managing detokenization, stop checking, and freeing/forking sequences with
the scheduler.
This is highly coupled with the LLMEngine and should be seen as an extension
of it. The logic is separated to simplify the LLMEngine class and allow
separate implementations for single-step decoding (which supports beam
search sequence forking) and multi-step decoding (which does not support
beam search, but does support speculative decoding).
"""
@staticmethod
def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
This returns a single-step output processor if num_lookahead_slots is
zero, else returns a multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(
scheduler_config,
detokenizer,
scheduler,
seq_counter,
stop_checker,
)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
MultiStepOutputProcessor)
return MultiStepOutputProcessor(
detokenizer,
scheduler,
seq_counter,
get_tokenizer_for_seq,
stop_checker,
)
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
"""
pass
from typing import Callable, Iterable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
logger = init_logger(__name__)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles logic related to
detokenization and stopping conditions. It specializes to "multi-step
decoding", where vLLM's worker may generate multiple tokens per invocation.
This is currently mutually exclusive with advanced sampling techniques like
beam search, which motivates the separation of this logic from the single
step output processor.
This class is responsible for things such as correctly appending all new
token ids to their sequence, detokenizing new token ids, truncating new
output tokens after an eos token, and correctly handling the case where the
number of new output tokens per sequence differs in a single batch.
"""
def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
"""
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
assert seqs, "expected running sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [outputs[step].samples[0] for step in range(len(outputs))]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode
# generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be
# unintentionally ignored.
if not sampling_params.ignore_eos:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path
# is expensive.
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id in output_token_ids:
seq.append_token_id(
token_id=output_token_id,
# TODO emit logprobs in multi-step decoding.
logprobs={output_token_id: Logprob(0.0)},
)
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params)
if seq.is_finished():
break
if seq.is_finished():
self.scheduler.free_seq(seq)
from typing import Iterable, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
logger = init_logger(__name__)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
scheduling of the next batch. Output processing logic includes
detokenization, and determining if a sequence is finished (e.g. via max len
or eos token).
The SingleStepOutputProcessor is specialized to the case where the model
emits at most a single token per invocation, which precludes configurations
such as speculative decoding or multi-step decoding. This enables beam
search sampling, which requires forking/finishing/freeing sequences in a way
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
self.seq_counter = seq_counter
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop
checking. This checks things such as: whether the eos token was emitted,
whether the max_tokens has been consumed, whether a stop string has been
emitted, or if we have exceeded the max model len.
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
self.max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
"""
if not new_char_count:
return None
for stop_str in sampling_params.stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = seq.output_text.find(
stop_str, -new_char_count - stop_string_len)
if stop_index == -1:
continue
if sampling_params.include_stop_str_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(seq.output_text):
# No truncation required.
return stop_str
# Truncate the output text to either the beginning
# or end of the stop string.
seq.output_text = seq.output_text[:stop_index]
return stop_str
return None
from typing import List
from vllm.sequence import SamplerOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
num_seq_groups: int):
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group = [[] for _ in range(num_seq_groups)]
for step in sampler_outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
return output_by_sequence_group
...@@ -74,7 +74,8 @@ class CPUExecutor(ExecutorBase): ...@@ -74,7 +74,8 @@ class CPUExecutor(ExecutorBase):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
......
...@@ -72,8 +72,9 @@ class ExecutorBase(ABC): ...@@ -72,8 +72,9 @@ class ExecutorBase(ABC):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
"""Executes one model step on the given sequences.""" num_lookahead_slots: int) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
...@@ -13,13 +13,17 @@ logger = init_logger(__name__) ...@@ -13,13 +13,17 @@ logger = init_logger(__name__)
class GPUExecutor(ExecutorBase): class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (not self.speculative_config """Initialize the worker and load the model.
), "Speculative decoding not yet supported for GPU backend"
# Instantiate the worker and load the model to GPU. If speculative decoding is enabled, we instead create the speculative
self._init_worker() worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
def _init_worker(self): def _init_non_spec_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -46,6 +50,57 @@ class GPUExecutor(ExecutorBase): ...@@ -46,6 +50,57 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker import Worker
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
target_worker = Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
draft_worker = MultiStepWorker(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
spec_decode_worker = SpecDecodeWorker.from_workers(
proposer_worker=draft_worker, scorer_worker=target_worker)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
...@@ -63,16 +118,20 @@ class GPUExecutor(ExecutorBase): ...@@ -63,16 +118,20 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self, def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots,
) )
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment