Unverified Commit 1a014a0a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Move MM encoder to Model States [3/N] (#35564)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 86ac7bcf
......@@ -89,7 +89,6 @@ class CudaGraphManager:
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
......@@ -116,9 +115,6 @@ class CudaGraphManager:
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
"inputs_embeds": (
inputs_embeds[:num_tokens] if inputs_embeds is not None else None
),
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
......@@ -255,7 +251,6 @@ class CudaGraphManager:
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
......@@ -267,7 +262,6 @@ class CudaGraphManager:
model=model,
model_state=model_state,
input_buffers=input_buffers,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
......
......@@ -66,8 +66,6 @@ class InputBatch:
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# [total_num_logits]
logits_indices: torch.Tensor
......@@ -138,7 +136,6 @@ class InputBatch:
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
inputs_embeds=None,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec
class EncoderCache:
def __init__(self):
# req_id -> MM features
self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
# MM hash -> encoder outputs
self.encoder_outputs: dict[str, torch.Tensor] = {}
def add_request(
self, req_id: str, mm_features: list[MultiModalFeatureSpec]
) -> None:
self.mm_features[req_id] = mm_features
def remove_request(self, req_id: str) -> None:
self.mm_features.pop(req_id, None)
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_outputs.clear()
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_outputs.pop(mm_hash, None)
......@@ -4,8 +4,9 @@ import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
......@@ -14,44 +15,19 @@ class EncoderRunner:
self,
max_num_tokens: int,
hidden_size: int,
encoder_cache: EncoderCache,
dtype: torch.dtype,
device: torch.device,
):
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.encoder_cache = encoder_cache
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]]
......@@ -59,7 +35,7 @@ class EncoderRunner:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
......@@ -90,7 +66,7 @@ class EncoderRunner:
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
......@@ -122,7 +98,7 @@ class EncoderRunner:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
......@@ -148,7 +124,7 @@ class EncoderRunner:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
......
......@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
get_kv_connector,
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
......@@ -127,20 +127,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
if self.supports_mm_inputs:
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
dtype=self.dtype,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
......@@ -162,6 +148,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
self.encoder_cache = None
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
......@@ -272,7 +267,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.speculator)
# Initialize the components that require the model.
self.model_state = ModelState(self.vllm_config, self.model, self.device)
self.model_state = ModelState(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
self.pooling_runner = PoolingRunner(self.model)
......@@ -435,12 +432,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc.collect()
def reset_mm_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_mm_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_mm_cache()
def reset_encoder_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_encoder_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.
......@@ -469,14 +466,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
inputs_embeds = None
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture(
model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers,
inputs_embeds=inputs_embeds,
block_tables=self.block_tables,
attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config,
......@@ -511,15 +504,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs:
if self.encoder_cache is not None:
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash)
self.encoder_cache.free_encoder_cache(mm_hash)
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
for new_req_data in scheduler_output.scheduled_new_reqs:
......@@ -535,8 +528,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
req_index = self.req_states.req_id_to_index[req_id]
if self.supports_mm_inputs:
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
if self.encoder_cache is not None:
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids(
......@@ -695,7 +688,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
inputs_embeds=None,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
......@@ -724,26 +716,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return block_tables, slot_mappings
@torch.inference_mode()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
return mm_embeds, is_mm_embed
def sample(
self,
hidden_states: torch.Tensor,
......@@ -890,18 +862,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and self.is_first_pp_rank:
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
input_batch.inputs_embeds = inputs_embeds[
: input_batch.num_tokens_after_padding
]
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
......@@ -934,10 +894,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config,
)
inputs_embeds = None
if self.supports_mm_inputs and self.is_first_pp_rank and not dummy_run:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
inputs_embeds = self.model_state.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs,
input_batch,
self.req_states,
)
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": input_batch.inputs_embeds,
"inputs_embeds": inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
......
......@@ -10,22 +10,43 @@ from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState:
def __init__(self, vllm_config: VllmConfig, model: nn.Module, device: torch.device):
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = vllm_config.scheduler_config
self.model = model
self.device = device
self.supports_mm_inputs = encoder_cache is not None
self.max_model_len = self.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dtype = self.model_config.dtype
if self.supports_mm_inputs:
assert encoder_cache is not None
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
......@@ -51,6 +72,29 @@ class ModelState:
if self.uses_mrope:
self.mrope_state.apply_staged_writes()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
req_states.prefill_len.np[input_batch.idx_mapping_np],
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
return inputs_embeds[: input_batch.num_tokens_after_padding]
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
......@@ -73,10 +117,14 @@ class ModelState:
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
if not self.uses_mrope:
return {}
model_inputs = {}
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
model_inputs["inputs_embeds"] = inputs_embeds
if self.uses_mrope:
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
return {"positions": mrope_positions}
model_inputs["positions"] = mrope_positions
return model_inputs
def prepare_attn(
self,
......
......@@ -44,7 +44,6 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment