"vllm/vscode:/vscode.git/clone" did not exist on "d95d0f4b985f28ea381e301490f9d479b34d8980"
Unverified Commit f5722a50 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[V1] Scatter and gather placeholders in the model runner (#15712)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 651cf0fe
...@@ -19,7 +19,8 @@ from vllm.config import VllmConfig ...@@ -19,7 +19,8 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -36,7 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler ...@@ -36,7 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -507,19 +509,47 @@ class TPUModelRunner: ...@@ -507,19 +509,47 @@ class TPUModelRunner:
logits_indices = logits_indices.to(self.device) logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices return attn_metadata, logits_indices
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _scatter_placeholders(
self,
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def _gather_placeholders(
self,
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return placeholders
return placeholders[is_embed]
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = [] mm_inputs = list[MultiModalKwargs]()
req_input_ids: list[tuple[str, int]] = [] req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for input_id in encoder_input_ids: for input_id, pos_info in zip(
encoder_input_ids,
req_state.mm_positions,
):
mm_inputs.append(req_state.mm_inputs[input_id]) mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id)) req_ids_pos.append((req_id, input_id, pos_info))
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
...@@ -555,16 +585,23 @@ class TPUModelRunner: ...@@ -555,16 +585,23 @@ class TPUModelRunner:
encoder_outputs.append(output) encoder_outputs.append(output)
# Cache the encoder outputs. # Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache: if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {} self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs( self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = [] mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id] req_id]
...@@ -572,8 +609,8 @@ class TPUModelRunner: ...@@ -572,8 +609,8 @@ class TPUModelRunner:
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions): for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"] start_pos = pos_info.offset
num_encoder_tokens = pos_info["length"] num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, # [num_computed_tokens,
...@@ -595,8 +632,16 @@ class TPUModelRunner: ...@@ -595,8 +632,16 @@ class TPUModelRunner:
assert req_id in self.encoder_cache assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id] assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i] encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
@torch.no_grad() @torch.no_grad()
def execute_model( def execute_model(
...@@ -612,10 +657,10 @@ class TPUModelRunner: ...@@ -612,10 +657,10 @@ class TPUModelRunner:
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
else: else:
encoder_outputs = [] mm_embeds = []
# Prepare inputs # Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
...@@ -623,9 +668,9 @@ class TPUModelRunner: ...@@ -623,9 +668,9 @@ class TPUModelRunner:
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
if encoder_outputs: if mm_embeds:
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs) self.input_ids, mm_embeds)
else: else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids) inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None input_ids = None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch import torch
...@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs( ...@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation " "instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.") "of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of :func:`scatter_mm_placeholders`.
"""
if is_embed is None:
return placeholders
return placeholders[is_embed]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment