import itertools
from typing import List, Optional, Set

import torch

from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.update_input import UpdateInputTokens


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Model-input builder that patches input tokens from the last sampler.

    While sequence groups are added, the builder records every sequence id in
    ``self.req_ids``. In ``build()`` those ids are matched against the ids
    held by the object returned from ``get_last_sampler()``, and the matching
    entries of ``model_input.input_tokens`` are overwritten with values taken
    from that sampler's ``sampled_token_ids_tensor``.
    """

    def __init__(self, runner, finished_requests_ids=None):
        """Initialise the parent builder and the per-batch sequence-id list.

        Args:
            runner: Forwarded to the parent builder; ``build()`` also reads
                ``runner.device`` and ``runner.pin_memory`` from it.
            finished_requests_ids: Forwarded unchanged to the parent builder.
        """
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the current batch, in the order they were added.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Drop the ids recorded so far, then delegate to the parent.

        Args:
            finished_requests_ids: Forwarded unchanged to the parent builder.
        """
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then delegate to the parent.

        Args:
            seq_group_metadata: Group whose ``seq_data`` keys are appended to
                ``self.req_ids`` in the mapping's iteration order.

        Returns:
            Whatever the parent ``add_seq_group`` returns.
        """
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and overwrite tokens from the last sampler.

        For every recorded sequence id that also appears in
        ``last_sampler.seq_ids``, the entry of ``model_input.input_tokens`` at
        that id's position in ``self.req_ids`` is replaced by column 0 of the
        corresponding row of ``last_sampler.sampled_token_ids_tensor``. When
        there is no last sampler, or no id matches, the parent's result is
        returned untouched.

        Returns:
            The model input produced by the parent builder, possibly with
            ``input_tokens`` modified in place.
        """
        model_input = super().build()

        last_sampler = get_last_sampler()
        if last_sampler is None:
            return model_input

        # Map each sampler-side id to the first row it appears in, so each
        # lookup is O(1) instead of a linear scan per sequence. setdefault
        # keeps the first occurrence, matching a scan that stops at the
        # first hit.
        # NOTE(review): assumes the entries of last_sampler.seq_ids are
        # hashable and compare equal to the seq_data keys - confirm.
        first_row_by_seq_id = {}
        for row, sampled_seq_id in enumerate(last_sampler.seq_ids):
            first_row_by_seq_id.setdefault(sampled_seq_id, row)

        select_indices = []
        update_indices = []
        for position, seq_id in enumerate(self.req_ids):
            row = first_row_by_seq_id.get(seq_id)
            if row is not None:
                select_indices.append(row)
                update_indices.append(position)

        # Nothing to patch: skip the host-to-device transfers entirely.
        if not select_indices:
            return model_input

        device = self.runner.device
        pin_memory = self.runner.pin_memory
        select_tensor = async_tensor_h2d(select_indices, torch.long, device,
                                         pin_memory)
        update_tensor = async_tensor_h2d(update_indices, torch.long, device,
                                         pin_memory)

        # NOTE(review): indexing input_tokens by position in req_ids assumes
        # one input token per recorded sequence, in insertion order - confirm
        # this holds for every batch type that reaches this builder.
        model_input.input_tokens[update_tensor] = (
            last_sampler.sampled_token_ids_tensor[select_tensor, 0])
        return model_input