Unverified Commit 002800f0 authored by Zhuohan Li, committed by GitHub

Align vLLM's beam search implementation with HF generate (#857)

parent e15932bb
...@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following ...@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ kv_caches: List[KVCache], + kv_caches: List[KVCache],
+ input_metadata: InputMetadata, + input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]], + cache_events: Optional[List[torch.cuda.Event]],
+) -> Dict[int, SequenceOutputs]: +) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture. 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
......
...@@ -67,8 +67,8 @@ class HfRunner: ...@@ -67,8 +67,8 @@ class HfRunner:
output_ids, output_ids,
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
)[0] )
output_ids = output_ids[0].cpu().tolist() output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str)) outputs.append((output_ids, output_str))
return outputs return outputs
...@@ -77,8 +77,34 @@ class HfRunner: ...@@ -77,8 +77,34 @@ class HfRunner:
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
-        return self.generate(prompts, do_sample=False,
-                             max_new_tokens=max_tokens)
+        outputs = self.generate(prompts,
+                                do_sample=False,
+                                max_new_tokens=max_tokens)
+        for i in range(len(outputs)):
+            output_ids, output_str = outputs[i]
+            outputs[i] = (output_ids[0], output_str[0])
+        return outputs
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [
x for x in output_ids[j]
if x != self.tokenizer.pad_token_id
]
outputs[i] = (output_ids, output_str)
return outputs
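For readers unfamiliar with the HF side of this comparison, here is a minimal sketch (not part of this commit) of the `transformers` call that `HfRunner.generate` wraps; the model name and prompt are illustrative. HF pads beams that finish early with `pad_token_id`, which is why the helper above strips pad tokens before comparing against vLLM.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer(["The capital of France is"], return_tensors="pt")
# Deterministic beam search returning every beam, mirroring
# generate_beam_search above (beam_width = 4, max_tokens = 128).
output_ids = model.generate(**inputs,
                            do_sample=False,
                            num_beams=4,
                            num_return_sequences=4,
                            max_new_tokens=128)
print(output_ids.shape)  # (4, prompt_len + generated_len), padded per beam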
@pytest.fixture @pytest.fixture
...@@ -107,15 +133,20 @@ class VllmRunner: ...@@ -107,15 +133,20 @@ class VllmRunner:
prompts: List[str], prompts: List[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
-        req_outputs = self.model.generate(
-            prompts, sampling_params=sampling_params)
+        req_outputs = self.model.generate(prompts,
+                                          sampling_params=sampling_params)
         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            output_str = req_output.outputs[0].text
-            output_ids = req_output.outputs[0].token_ids
-            outputs.append((prompt_ids + output_ids, prompt_str + output_str))
+            req_sample_output_ids = []
+            req_sample_output_strs = []
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                req_sample_output_ids.append(prompt_ids + output_ids)
+                req_sample_output_strs.append(prompt_str + output_str)
+            outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs return outputs
def generate_greedy( def generate_greedy(
...@@ -124,7 +155,22 @@ class VllmRunner: ...@@ -124,7 +155,22 @@ class VllmRunner:
max_tokens: int, max_tokens: int,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        return self.generate(prompts, greedy_params)
+        outputs = self.generate(prompts, greedy_params)
+        return [(output_ids[0], output_str[0]) for output_ids, output_str in
+                outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
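On the vLLM side, a minimal usage sketch (not part of this commit) of how the same beam-search configuration is expressed through `SamplingParams`; the prompt is illustrative and the output fields follow the runner above.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
beam_search_params = SamplingParams(n=4,  # beam width
                                    use_beam_search=True,
                                    temperature=0.0,
                                    max_tokens=128)
req_outputs = llm.generate(["The capital of France is"],
                           sampling_params=beam_search_params)
for req_output in req_outputs:
    # Each request now exposes `n` candidate beams.
    for beam in req_output.outputs:
        print(beam.token_ids, beam.text)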
@pytest.fixture @pytest.fixture
......
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
...@@ -172,9 +172,7 @@ class BlockSpaceManager: ...@@ -172,9 +172,7 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block. # CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
...@@ -203,9 +201,7 @@ class BlockSpaceManager: ...@@ -203,9 +201,7 @@ class BlockSpaceManager:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block. # GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
......
...@@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager ...@@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
-from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
+                           SequenceGroupMetadata, SequenceStatus)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,6 +75,7 @@ class Scheduler: ...@@ -76,6 +75,7 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
) )
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = [] self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
...@@ -96,10 +96,11 @@ class Scheduler: ...@@ -96,10 +96,11 @@ class Scheduler:
if seq_group.request_id in request_ids: if seq_group.request_id in request_ids:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(seq_group) state_queue.remove(seq_group)
-                for seq in seq_group.seqs:
+                for seq in seq_group.get_seqs():
                     if seq.is_finished():
                         continue
-                    self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    seq.status = SequenceStatus.FINISHED_ABORTED
+                    self.free_seq(seq)
request_ids.remove(seq_group.request_id) request_ids.remove(seq_group.request_id)
if not request_ids: if not request_ids:
return return
...@@ -123,6 +124,10 @@ class Scheduler: ...@@ -123,6 +124,10 @@ class Scheduler:
if not self.swapped: if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
num_batched_tokens = 0 num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
...@@ -130,6 +135,9 @@ class Scheduler: ...@@ -130,6 +135,9 @@ class Scheduler:
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
assert seq_group.num_seqs() == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len() num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens > self.prompt_limit: if num_prompt_tokens > self.prompt_limit:
logger.warning( logger.warning(
...@@ -152,11 +160,7 @@ class Scheduler: ...@@ -152,11 +160,7 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(
-                status=SequenceStatus.WAITING)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
+            num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs > if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
...@@ -165,6 +169,7 @@ class Scheduler: ...@@ -165,6 +169,7 @@ class Scheduler:
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
if scheduled: if scheduled:
...@@ -210,30 +215,32 @@ class Scheduler: ...@@ -210,30 +215,32 @@ class Scheduler:
# Swap in the sequence groups in the SWAPPED state if possible. # Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped) self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped and not blocks_to_swap_out:
-            seq_group = self.swapped[0]
-            # If the sequence group has been preempted in this step, stop.
-            if seq_group in preempted:
-                break
-            # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
-                break
+        if not preempted:
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
+
+            while self.swapped:
+                seq_group = self.swapped[0]
+                # If the sequence group cannot be swapped in, stop.
+                if not self.block_manager.can_swap_in(seq_group):
+                    break

-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
-            if (num_curr_seqs + num_new_seqs >
-                    self.scheduler_config.max_num_seqs):
-                break
+                # The total number of sequences in the RUNNING state should
+                # not exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break

-            seq_group = self.swapped.pop(0)
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slot(seq_group, blocks_to_copy)
-            self.running.append(seq_group)
+                seq_group = self.swapped.pop(0)
+                self._swap_in(seq_group, blocks_to_swap_in)
+                self._append_slot(seq_group, blocks_to_copy)
+                num_curr_seqs += num_new_seqs
+                self.running.append(seq_group)

+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)
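A quick arithmetic sanity check on the admission control above (illustrative numbers, not from this diff): because get_max_num_running_seqs() reserves the full beam width for a group up front, a beam-search group counts as best_of slots even while it is still a single prompt sequence.

# Illustrative only: scheduler budget with hypothetical settings.
max_num_seqs = 256   # scheduler_config.max_num_seqs (assumed value)
best_of = 4          # beam width of each incoming group (assumed value)

# Each group reserves `best_of` slots, so at most this many beam-search
# groups can be admitted to the RUNNING/swapped-in set at once.
max_concurrent_groups = max_num_seqs // best_of
print(max_concurrent_groups)  # 64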
...@@ -275,40 +282,10 @@ class Scheduler: ...@@ -275,40 +282,10 @@ class Scheduler:
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
-    def update(
-        self,
-        seq_outputs: Dict[int, SequenceOutputs],
-    ) -> List[SequenceGroup]:
-        scheduled: List[SequenceGroup] = []
-        for seq_group in self.running:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                if seq.seq_id in seq_outputs:
-                    scheduled.append(seq_group)
-                    break
-
-        # Update the scheduled sequences and free blocks.
-        for seq_group in scheduled:
-            # Process beam search results before processing the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                output = seq_outputs[seq.seq_id]
-                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam
-                    # search). Free the current sequence.
-                    self.block_manager.free(seq)
-                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(output.parent_seq_id)
-                    parent_seq.fork(seq)
-                    self.block_manager.fork(parent_seq, seq)
-
-            # Process the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Append a new token to the sequence.
-                output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
-        return scheduled
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)

-    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
-        seq.status = finish_status
+    def free_seq(self, seq: Sequence) -> None:
         self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
...@@ -345,8 +322,8 @@ class Scheduler: ...@@ -345,8 +322,8 @@ class Scheduler:
# If preemption mode is not specified, we determine the mode as follows: # If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than # We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences # swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case, # (e.g., beam search), recomputation is not currently supported. In
# we use swapping instead. # such a case, we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre. # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences, # As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized # sequence groups with multiple sequences are implicitly prioritized
...@@ -354,8 +331,7 @@ class Scheduler: ...@@ -354,8 +331,7 @@ class Scheduler:
# TODO(woosuk): Support recomputation for sequence groups with multiple # TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel. # sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None: if preemption_mode is None:
-            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-            if len(seqs) == 1:
+            if seq_group.get_max_num_running_seqs() == 1:
preemption_mode = PreemptionMode.RECOMPUTE preemption_mode = PreemptionMode.RECOMPUTE
else: else:
preemption_mode = PreemptionMode.SWAP preemption_mode = PreemptionMode.SWAP
......
...@@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray ...@@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceStatus)
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
+                           SequenceGroupMetadata, SequenceOutputs,
+                           SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) get_tokenizer)
...@@ -258,14 +259,11 @@ class LLMEngine: ...@@ -258,14 +259,11 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
-        seqs: List[Sequence] = []
-        for _ in range(sampling_params.best_of):
-            seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
-            seqs.append(seq)
+        seq_id = next(self.seq_counter)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

         # Create the sequence group.
-        seq_group = SequenceGroup(request_id, seqs, sampling_params,
-                                  arrival_time)
+        seq_group = SequenceGroup(request_id, [seq], sampling_params,
+                                  arrival_time)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
...@@ -303,22 +301,230 @@ class LLMEngine: ...@@ -303,22 +301,230 @@ class LLMEngine:
] ]
return seq_group_metadata_list, scheduler_outputs, None return seq_group_metadata_list, scheduler_outputs, None
-    def _process_worker_outputs(
-        self, output,
+    def _check_beam_search_early_stopping(
+        self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
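The score compared here comes from Sequence.get_beam_search_score, which is defined outside this hunk. A hedged sketch of the HF-style length-penalized score it is understood to compute (treat the exact signature as an assumption):

def get_beam_search_score(cumulative_logprob: float,
                          seq_len: int,
                          length_penalty: float = 1.0) -> float:
    # Cumulative log-probability normalized by seq_len ** length_penalty.
    # Since cumulative_logprob is negative, length_penalty > 0.0 makes longer
    # sequences score higher, which is why the "never" branch above bounds the
    # attainable score using the longest possible sequence length.
    return cumulative_logprob / (seq_len ** length_penalty)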
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput,
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
-        # Update the scheduler with the model outputs.
-        seq_groups = self.scheduler.update(output)
-
-        # Decode the sequences.
-        self._decode_sequences(seq_groups)
-        # Stop the sequences that meet the stopping criteria.
-        self._stop_sequences(seq_groups)
+        # Update the scheduled sequence groups with the model outputs.
+        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+        for seq_group, samples in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_samples(seq_group, samples)

         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
+        for seq_group in (scheduled_seq_groups +
+                          scheduler_outputs.ignored_seq_groups):
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
...@@ -351,7 +557,7 @@ class LLMEngine: ...@@ -351,7 +557,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
return self._process_worker_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats( def _log_system_stats(
self, self,
...@@ -416,55 +622,44 @@ class LLMEngine: ...@@ -416,55 +622,44 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now self.last_logging_time = now
-    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        """Decodes the sequence outputs."""
-        for seq_group in seq_groups:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                new_token, new_output_text = detokenize_incrementally(
-                    self.tokenizer,
-                    seq.output_tokens,
-                    seq.get_last_token_id(),
-                    skip_special_tokens=True,
-                )
-                if new_token is not None:
-                    seq.output_tokens.append(new_token)
-                    seq.output_text = new_output_text
-
-    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        """Stop the finished sequences."""
-        for seq_group in seq_groups:
-            sampling_params = seq_group.sampling_params
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Check if the sequence has generated a stop string.
-                stopped = False
-                for stop_str in sampling_params.stop:
-                    if seq.output_text.endswith(stop_str):
-                        # Truncate the output text so that the stop string is
-                        # not included in the output.
-                        seq.output_text = seq.output_text[:-len(stop_str)]
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        stopped = True
-                        break
-                if stopped:
-                    continue
-
-                # Check if the sequence has reached max_model_len.
-                if seq.get_len() > self.scheduler_config.max_model_len:
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has reached max_tokens.
-                if seq.get_output_len() == sampling_params.max_tokens:
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has generated the EOS token.
-                if not sampling_params.ignore_eos:
-                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        continue
+    def _decode_sequence(self, seq: Sequence) -> None:
+        """Decodes the new token for a sequence."""
+        new_token, new_output_text = detokenize_incrementally(
+            self.tokenizer,
+            seq.output_tokens,
+            seq.get_last_token_id(),
+            skip_special_tokens=True,
+        )
+        if new_token is not None:
+            seq.output_tokens.append(new_token)
+            seq.output_text = new_output_text
+
+    def _check_stop(self, seq: Sequence,
+                    sampling_params: SamplingParams) -> None:
+        """Stop the finished sequences."""
+        for stop_str in sampling_params.stop:
+            if seq.output_text.endswith(stop_str):
+                # Truncate the output text so that the stop string is
+                # not included in the output.
+                seq.output_text = seq.output_text[:-len(stop_str)]
+                seq.status = SequenceStatus.FINISHED_STOPPED
+                return
+
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
def _run_workers( def _run_workers(
self, self,
......
...@@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata ...@@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region) gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput, SequenceOutputs
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -39,7 +39,7 @@ class Sampler(nn.Module): ...@@ -39,7 +39,7 @@ class Sampler(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
# Get the hidden states that we use for sampling. # Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata) hidden_states = _prune_hidden_states(hidden_states, input_metadata)
...@@ -292,7 +292,13 @@ def _sample_from_prompt( ...@@ -292,7 +292,13 @@ def _sample_from_prompt(
if sampling_params.use_beam_search: if sampling_params.use_beam_search:
# Beam search. # Beam search.
beam_width = sampling_params.best_of beam_width = sampling_params.best_of
-        _, next_token_ids = torch.topk(prob, beam_width)
+        # Sample 2 * beam_width candidates to make sure that with high
+        # probability we can get `beam_width` candidates in addition to
+        # the finished sequences for the next iteration. See
+        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+        # for details. See also HF reference:
+        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+        _, next_token_ids = torch.topk(prob, 2 * beam_width)
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature < _SAMPLING_EPS: elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling. # Greedy sampling.
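A small illustration (toy tensor, not taken from this diff) of the 2 * beam_width oversampling described in the comment above: even if up to beam_width of the candidates terminate their beams with EOS, at least beam_width live candidates remain for the next step.

import torch

beam_width, vocab_size = 4, 10
# Toy per-beam log-probabilities standing in for the sampler's `logprobs`.
logprobs = torch.randn(beam_width, vocab_size).log_softmax(dim=-1)

# Same flatten/topk/divmod pattern the sampler uses for generation tokens.
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
parent_beams = [i // vocab_size for i in topk_ids.tolist()]
next_tokens = [i % vocab_size for i in topk_ids.tolist()]
print(len(next_tokens))  # 8 candidates for a beam width of 4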
...@@ -330,29 +336,11 @@ def _sample_from_generation_tokens( ...@@ -330,29 +336,11 @@ def _sample_from_generation_tokens(
vocab_size = logprobs.size(-1) vocab_size = logprobs.size(-1)
beam_width = len(seq_ids) beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
+        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
         topk_ids = topk_ids.tolist()
         seq_idx = [i // vocab_size for i in topk_ids]
-        beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = [i % vocab_size for i in topk_ids]
-
-        beam_outputs: Dict[int, Tuple[int, int]] = {}
-        outstanding_beams: List[Tuple[int, int]] = []
-        # If a beam survives, continue with it.
-        for seq_id, token_id in zip(beam_seq_ids, token_ids):
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = (seq_id, token_id)
-            else:
-                outstanding_beams.append((seq_id, token_id))
-
-        # If a beam is discarded, fork another beam.
-        for seq_id in seq_ids:
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = outstanding_beams.pop()
-        assert not outstanding_beams
-
-        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
-        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
+        parent_seq_ids = [seq_ids[i] for i in seq_idx]
+        next_token_ids = [i % vocab_size for i in topk_ids]
elif sampling_params.temperature < _SAMPLING_EPS: elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling. # Greedy sampling.
assert len(seq_ids) == 1 assert len(seq_ids) == 1
...@@ -374,16 +362,18 @@ def _sample( ...@@ -374,16 +362,18 @@ def _sample(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
seq_outputs: Dict[int, SequenceOutputs] = {} seq_outputs: SamplerOutput = []
# TODO(woosuk): Optimize. # TODO(woosuk): Optimize.
idx = 0 idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_group_outputs: List[SequenceOutputs] = []
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts: if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input. # Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.best_of assert len(seq_ids) == 1, "Prompt input should have only one seq."
parent_seq_id = seq_ids[0]
prob = probs[idx] prob = probs[idx]
logprob = logprobs[idx] logprob = logprobs[idx]
idx += 1 idx += 1
...@@ -395,17 +385,18 @@ def _sample( ...@@ -395,17 +385,18 @@ def _sample(
sampling_params.logprobs) sampling_params.logprobs)
# Build the output. # Build the output.
-            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
+            for next_token_id in next_token_ids:
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
-                                                      next_token_id,
-                                                      output_logprobs)
+                seq_group_outputs.append(
+                    SequenceOutputs(parent_seq_id, next_token_id,
+                                    output_logprobs))
else: else:
# Generate the next tokens for generation tokens. # Generate the next tokens for generation tokens.
-            prob = probs[idx:idx + len(seq_ids)]
-            logprob = logprobs[idx:idx + len(seq_ids)]
-            idx += len(seq_ids)
+            num_parent_seqs = len(seq_ids)
+            prob = probs[idx:idx + num_parent_seqs]
+            logprob = logprobs[idx:idx + num_parent_seqs]
+            idx += num_parent_seqs
# Sample the next tokens. # Sample the next tokens.
seq_logprobs = [ seq_logprobs = [
...@@ -422,17 +413,15 @@ def _sample( ...@@ -422,17 +413,15 @@ def _sample(
logprob[j], sampling_params.logprobs) logprob[j], sampling_params.logprobs)
# Build the output. # Build the output.
-            for seq_id, parent_seq_id, next_token_id in zip(
-                    seq_ids, parent_seq_ids, next_token_ids):
+            for parent_seq_id, next_token_id in zip(parent_seq_ids,
+                                                    next_token_ids):
                 j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
                 output_logprobs[next_token_id] = logprob[j,
                                                          next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id,
-                    parent_seq_id,
-                    next_token_id,
-                    output_logprobs,
-                )
+                seq_group_outputs.append(
+                    SequenceOutputs(parent_seq_id, next_token_id,
+                                    output_logprobs))
+        seq_outputs.append(seq_group_outputs)
return seq_outputs return seq_outputs
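For reference, a hedged sketch of the shape SamplerOutput now takes, inferred from the usages in this diff (the real definitions live in vllm/sequence.py): one inner list per sequence group, each holding SequenceOutputs entries keyed by the parent sequence they extend.

from typing import Dict, List

class SequenceOutputs:  # simplified stand-in for the class in vllm/sequence.py
    def __init__(self, parent_seq_id: int, output_token: int,
                 logprobs: Dict[int, float]) -> None:
        self.parent_seq_id = parent_seq_id  # sequence this sample continues
        self.output_token = output_token    # sampled token id
        self.logprobs = logprobs            # token id -> log-probability

# One entry per scheduled sequence group; each entry lists that group's samples.
SamplerOutput = List[List[SequenceOutputs]]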
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig from vllm.transformers_utils.configs.aquila import AquilaConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module): ...@@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......
...@@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses ...@@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......
...@@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses ...@@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module): ...@@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region) reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module): ...@@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
positions, positions,
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module): ...@@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module): ...@@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states, next_tokens = self.sampler(self.embed_out.weight, hidden_states,
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import ( ...@@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.weight_utils import ( from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab, hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights) load_tensor_parallel_weights)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module): ...@@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......
# coding=utf-8 # coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module): ...@@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module): ...@@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import ( ...@@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
......