Commit 7d867671 authored by lizhigong's avatar lizhigong
Browse files

fix llm_engine to zero_overhead

parent 08c2298a
...@@ -1233,6 +1233,27 @@ class LLMEngine: ...@@ -1233,6 +1233,27 @@ class LLMEngine:
return None return None
def _fix_last_step(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group, token_id in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups, sample_out_list):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
continue
if seq_group_metadata.do_sample:
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
sample.output_token = token_id[0]
seq.fix_last_token_id(sample.output_token)
def _advance_to_next_step( def _advance_to_next_step(
self, output: List[SamplerOutput], self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -1386,9 +1407,14 @@ class LLMEngine: ...@@ -1386,9 +1407,14 @@ class LLMEngine:
assert scheduler_outputs is not None assert scheduler_outputs is not None
profile.ProfRangeAutoPush('execute_model') profile.ProfRangeAutoPush('execute_model')
last_outputs = None last_outputs_ids = None
last_outputs_tensor = None
if self.zero_overhead: if self.zero_overhead:
last_outputs = self.trans_last_output_tensor(self.output_recorder[self.step_switch]) recode_output = self.output_recorder[self.step_switch]
if recode_output is not None:
last_output = recode_output[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.output_recorder[self.step_switch] = None # only use for once
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Check if we have a cached last_output from the previous iteration. # Check if we have a cached last_output from the previous iteration.
...@@ -1398,14 +1424,6 @@ class LLMEngine: ...@@ -1398,14 +1424,6 @@ class LLMEngine:
last_sampled_token_ids = \ last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine) self._get_last_sampled_token_ids(virtual_engine)
# print('seq_group_metadata_list', len(seq_group_metadata_list))
# print('scheduler_outputs.blocks_to_swap_in', len(scheduler_outputs.blocks_to_swap_in))
# print('scheduler_outputs.num_lookahead_slots', scheduler_outputs.num_lookahead_slots)
# print('scheduler_outputs.running_queue_size', scheduler_outputs.running_queue_size)
# print('finished_requests_ids', len(finished_requests_ids))
# print('last_sampled_token_ids', last_sampled_token_ids)
# print('self.model_executor', type(self.model_executor))
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
...@@ -1417,7 +1435,8 @@ class LLMEngine: ...@@ -1417,7 +1435,8 @@ class LLMEngine:
# We use ExecuteModelRequest to pass the last sampled_token_ids # We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids, last_sampled_token_ids=last_sampled_token_ids,
last_outputs = last_outputs) last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
if allow_async_output_proc: if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
...@@ -1437,19 +1456,28 @@ class LLMEngine: ...@@ -1437,19 +1456,28 @@ class LLMEngine:
# No outputs in this case # No outputs in this case
outputs = [] outputs = []
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if self.zero_overhead: if self.zero_overhead:
self.output_recorder[self.step_switch] = outputs self.output_recorder[self.step_switch] = [outputs, seq_group_metadata_list, scheduler_outputs]
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self.step_switch = 1 - self.step_switch self.step_switch = 1 - self.step_switch
outputs = self.output_recorder[self.step_switch]
if outputs is None: recode_output = self.output_recorder[self.step_switch]
if recode_output is None:
return None return None
#同步上一次的output outputs, seq_group_metadata_list, scheduler_outputs = self.output_recorder[self.step_switch]
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self._fix_last_step(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list): if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps. # clear the cache if we have finished all the steps.
...@@ -1473,7 +1501,7 @@ class LLMEngine: ...@@ -1473,7 +1501,7 @@ class LLMEngine:
if outputs and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len(outputs) == 1, ( assert len(outputs) == 1, (
"Async postprocessor expects only a single output set") "Async postprocessor expects only a single output set")
if not self.zero_overhead:
self._advance_to_next_step( self._advance_to_next_step(
outputs[0], seq_group_metadata_list, outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
...@@ -1505,7 +1533,6 @@ class LLMEngine: ...@@ -1505,7 +1533,6 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs return ctx.request_outputs
def _has_remaining_steps( def _has_remaining_steps(
......
...@@ -1388,6 +1388,7 @@ class LLM: ...@@ -1388,6 +1388,7 @@ class LLM:
total_out_toks = 0 total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
print('###step_outputs', step_outputs)
if step_outputs is None: if step_outputs is None:
continue continue
for output in step_outputs: for output in step_outputs:
......
...@@ -22,3 +22,8 @@ def _update_input_tokens( ...@@ -22,3 +22,8 @@ def _update_input_tokens(
if _seq_ids == _input_seq_id: if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i) output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token) tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
last_ids = last_ids.to('cuda')
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
...@@ -72,11 +73,12 @@ class SampleResultArgsType: ...@@ -72,11 +73,12 @@ class SampleResultArgsType:
# Implemented by guanyu # Implemented by guanyu
@dataclass @dataclass
class SampleDeviceToDevices: class SampleDeviceToDevices:
num_parent_seq: torch.Tensor=None def __init__(self):
seq_id:torch.Tensor=None self.seq_id:torch.Tensor = None
random_samples:torch.Tensor=None self.random_samples:torch.Tensor = None
sample_idx:int=None self.zero_overhead:bool = False
d2d_data=SampleDeviceToDevices()
d2d_data = SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling) # Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling) # vs deferred (multi-step scheduling)
...@@ -144,6 +146,9 @@ class SamplerOutput( ...@@ -144,6 +146,9 @@ class SamplerOutput(
# tree-style cartesian candidates # tree-style cartesian candidates
tree_attn_masks: Optional[torch.Tensor] = None tree_attn_masks: Optional[torch.Tensor] = None
sampler_out_tenosr : Optional[torch.Tensor] = None
sampler_out_ids : Optional[torch.Tensor] = None
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx] return self.outputs[idx]
...@@ -174,7 +179,10 @@ class SamplerOutput( ...@@ -174,7 +179,10 @@ class SamplerOutput(
f"sampled_token_ids={sampled_token_ids_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, " f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, "
f"logits={self.logits}, " f"logits={self.logits}, "
f"tree_attn_masks={self.tree_attn_masks})") f"tree_attn_masks={self.tree_attn_masks}, "
f"sampler_out_tenosr={self.sampler_out_tenosr}, "
f"sampler_out_ids={self.sampler_out_ids}, "
f")")
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -206,6 +214,8 @@ class Sampler(nn.Module): ...@@ -206,6 +214,8 @@ class Sampler(nn.Module):
# speculative decoding. # speculative decoding.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -503,7 +513,7 @@ def _random_sample( ...@@ -503,7 +513,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum n value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
#random_samples = random_samples.cpu()删除,取消gpu->cpu之间的同步 if not d2d_data.zero_overhead:
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -516,20 +526,24 @@ def _random_sample( ...@@ -516,20 +526,24 @@ def _random_sample(
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
d2d_data.num_parent_seq = num_parent_seqs
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.n parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead:
next_token_ids = [0] * sampling_params.n
else:
next_token_ids = random_samples[ next_token_ids = random_samples[
sample_idx, :sampling_params.n].tolist() sample_idx, :sampling_params.n].tolist()
else: else:
# Generation phase. # Generation phase.
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
next_token_ids = [0] * num_parent_seqs
else:
next_token_ids = random_samples[sample_idx:sample_idx + next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist() num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
d2d_data.sample_idx=sample_idx
return results return results
...@@ -707,7 +721,7 @@ def get_pythonized_sample_results( ...@@ -707,7 +721,7 @@ def get_pythonized_sample_results(
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples) sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
d2d_data.random_samples=multinomial_samples[sampling_type]#记录random_samples的数据 d2d_data.random_samples = multinomial_samples[sampling_type]#记录random_samples的数据
sample_results = _random_sample(seq_groups, sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type]) multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
...@@ -744,13 +758,11 @@ def _sample_with_torch( ...@@ -744,13 +758,11 @@ def _sample_with_torch(
categorized_seq_group_ids: Dict[SamplingType, List[int]] = { categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: [] t: []
for t in SamplingType for t in SamplingType
}#初始化各种结果存储容器然后按照类型分类 }
print(f'sampling_metadata.seq_groups的长度:{len(sampling_metadata.seq_groups)}') d2d_data.seq_id = torch.zeros(len(sampling_metadata.seq_groups))
# 初始化一个tensor张量用于保存seq_id,初始值为-1
d2d_data.seq_id=torch.zeros(len(sampling_metadata.seq_groups),1)-1
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i]=seq_group.seq_ids[0]#将 i对应的seq_id存储到d2d_data.seq_id中 d2d_data.seq_id[i] = seq_group.seq_ids[0]
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
...@@ -1280,13 +1292,18 @@ def _build_sampler_output( ...@@ -1280,13 +1292,18 @@ def _build_sampler_output(
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None) None)
if d2d_data.zero_overhead:
pass
return SamplerOutput( return SamplerOutput(
outputs=sampler_output, outputs=sampler_output,
sampled_token_probs=sampled_token_probs, sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args, deferred_sample_results_args=deferred_sample_results_args,
logits=logits) logits=logits,
sampler_out_tenosr = d2d_data.random_samples,
sampler_out_ids = d2d_data.seq_id)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
...@@ -582,6 +582,11 @@ class Sequence: ...@@ -582,6 +582,11 @@ class Sequence:
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
def fix_last_token_id(self, token_id: int) -> None:
self.data._output_token_ids[-2] = token_id
self.data._new_appended_tokens[-2] = token_id
self.data._cached_all_token_ids[-2] = token_id
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
...@@ -1403,7 +1408,10 @@ class ExecuteModelRequest( ...@@ -1403,7 +1408,10 @@ class ExecuteModelRequest(
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler # for zero-overhead scheduler
last_outputs : Optional[torch.Tensor] = None last_outputs_sample : Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_ids : Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
...@@ -1455,7 +1463,8 @@ class ExecuteModelRequest( ...@@ -1455,7 +1463,8 @@ class ExecuteModelRequest(
tree_attn_masks=self.tree_attn_masks, tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids, tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved, kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
last_outputs = self.last_outputs) last_outputs_sample = self.last_outputs_sample,
last_outputs_ids = self.last_outputs_ids)
@dataclass @dataclass
......
...@@ -4,6 +4,7 @@ import dataclasses ...@@ -4,6 +4,7 @@ import dataclasses
import gc import gc
import inspect import inspect
import itertools import itertools
import os
import time import time
import weakref import weakref
from contextlib import contextmanager from contextlib import contextmanager
...@@ -59,6 +60,8 @@ from vllm.worker.model_runner_base import ( ...@@ -59,6 +60,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.model_executor.layers.ops.update_input import UpdateInputTokens
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -476,6 +479,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -476,6 +479,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.last_sample_tensor = None
self.last_sample_ids = None
self.req_ids = []
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
self.last_sample_tensor = last_sample_tensor
self.last_sample_ids = last_sample_ids
def prepare(self, def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None: finished_requests_ids: Optional[List[str]] = None) -> None:
...@@ -491,6 +502,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -491,6 +502,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
ModelInputForGPUBuilder.InterDataForSeqGroup] = [] ModelInputForGPUBuilder.InterDataForSeqGroup] = []
self.attn_metadata_builder.prepare() self.attn_metadata_builder.prepare()
self.req_ids.clear()
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
...@@ -756,8 +768,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -756,8 +768,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len=encoder_seq_len) encoder_seq_len=encoder_seq_len)
self.inter_data_list.append(inter_data) self.inter_data_list.append(inter_data)
seq_ids = list(seq_ids)
for seq_idx in range(n_seqs): for seq_idx in range(n_seqs):
self.req_ids.append(seq_ids[seq_idx])
for per_seq_fn in self.per_seq_compute_fns: for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata) per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns: for per_seq_group_fn in self.per_seq_group_compute_fns:
...@@ -898,10 +911,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -898,10 +911,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if cuda_graph_pad_size: if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
if self.zero_overhead and self.last_sample_tensor is not None:
input_ids = torch.tensor(self.req_ids, device='cuda')
UpdateInputTokens(input_tokens_tensor, input_ids, self.last_sample_tensor, self.last_sample_ids)
print('####input_tokens_tensor', input_tokens_tensor)
print('####input_ids', input_ids)
print('####self.last_sample_tensor', self.last_sample_tensor)
print('####self.last_sample_ids', self.last_sample_ids)
token_types_tensor = async_tensor_h2d(token_types, torch.long, token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device, self.runner.device,
...@@ -1200,7 +1221,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1200,7 +1221,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> TModelInputForGPU: ) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
...@@ -1221,7 +1244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1221,7 +1244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.builder.add_seq_group(seq_group_metadata) self.builder.add_seq_group(seq_group_metadata)
self.builder.reset_cached_inter_data() self.builder.reset_cached_inter_data()
self.builder.SetLastSamperData(last_outputs_ids, last_output_sample)
return self.builder.build() # type: ignore return self.builder.build() # type: ignore
@contextmanager @contextmanager
...@@ -1616,6 +1639,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1616,6 +1639,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1631,7 +1656,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1631,7 +1656,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids) seq_group_metadata_list, finished_requests_ids, last_outputs_ids, last_output_sample)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group # Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids) generators = self.get_generators(finished_requests_ids)
......
...@@ -189,7 +189,6 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -189,7 +189,6 @@ class ModelRunnerBase(ABC, Generic[T]):
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.last_output = None
# Map of request_id -> generator used for seeded random sampling # Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {} generators: Dict[str, torch.Generator] = {}
...@@ -211,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -211,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
...@@ -353,12 +353,13 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -353,12 +353,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
self.model_runner.last_output = execute_model_req.last_outputs
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine, execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids)) execute_model_req.finished_requests_ids,
last_outputs_ids = execute_model_req.last_outputs_ids,
last_output_sample = execute_model_req.last_outputs_sample))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \ if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None: execute_model_req.tree_attn_masks is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment