Unverified Commit 1a014a0a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Move MM encoder to Model States [3/N] (#35564)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 86ac7bcf
...@@ -89,7 +89,6 @@ class CudaGraphManager: ...@@ -89,7 +89,6 @@ class CudaGraphManager:
model: nn.Module, model: nn.Module,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
...@@ -116,9 +115,6 @@ class CudaGraphManager: ...@@ -116,9 +115,6 @@ class CudaGraphManager:
model_inputs = { model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens], "input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens], "positions": input_buffers.positions[:num_tokens],
"inputs_embeds": (
inputs_embeds[:num_tokens] if inputs_embeds is not None else None
),
# NOTE: Values returned by `prepare_dummy_inputs` will override the # NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above. # default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens), **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
...@@ -255,7 +251,6 @@ class CudaGraphManager: ...@@ -255,7 +251,6 @@ class CudaGraphManager:
model: nn.Module, model: nn.Module,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
...@@ -267,7 +262,6 @@ class CudaGraphManager: ...@@ -267,7 +262,6 @@ class CudaGraphManager:
model=model, model=model,
model_state=model_state, model_state=model_state,
input_buffers=input_buffers, input_buffers=input_buffers,
inputs_embeds=inputs_embeds,
block_tables=block_tables, block_tables=block_tables,
attn_groups=attn_groups, attn_groups=attn_groups,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
......
...@@ -66,8 +66,6 @@ class InputBatch: ...@@ -66,8 +66,6 @@ class InputBatch:
input_ids: torch.Tensor input_ids: torch.Tensor
# [num_tokens_after_padding] # [num_tokens_after_padding]
positions: torch.Tensor positions: torch.Tensor
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# [total_num_logits] # [total_num_logits]
logits_indices: torch.Tensor logits_indices: torch.Tensor
...@@ -138,7 +136,6 @@ class InputBatch: ...@@ -138,7 +136,6 @@ class InputBatch:
dcp_local_seq_lens=None, dcp_local_seq_lens=None,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
inputs_embeds=None,
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec
class EncoderCache:
def __init__(self):
# req_id -> MM features
self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
# MM hash -> encoder outputs
self.encoder_outputs: dict[str, torch.Tensor] = {}
def add_request(
self, req_id: str, mm_features: list[MultiModalFeatureSpec]
) -> None:
self.mm_features[req_id] = mm_features
def remove_request(self, req_id: str) -> None:
self.mm_features.pop(req_id, None)
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_outputs.clear()
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_outputs.pop(mm_hash, None)
...@@ -4,8 +4,9 @@ import numpy as np ...@@ -4,8 +4,9 @@ import numpy as np
import torch import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
...@@ -14,44 +15,19 @@ class EncoderRunner: ...@@ -14,44 +15,19 @@ class EncoderRunner:
self, self,
max_num_tokens: int, max_num_tokens: int,
hidden_size: int, hidden_size: int,
encoder_cache: EncoderCache,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
): ):
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.encoder_cache = encoder_cache
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.inputs_embeds = torch.zeros( self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device max_num_tokens, hidden_size, dtype=dtype, device=device
) )
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs( def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]] self, scheduled_encoder_inputs: dict[str, list[int]]
...@@ -59,7 +35,7 @@ class EncoderRunner: ...@@ -59,7 +35,7 @@ class EncoderRunner:
mm_hashes: list[str] = [] mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = [] mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id] mm_features = self.encoder_cache.mm_features[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id] mm_feature = mm_features[mm_input_id]
if mm_feature.data is None: if mm_feature.data is None:
...@@ -90,7 +66,7 @@ class EncoderRunner: ...@@ -90,7 +66,7 @@ class EncoderRunner:
encoder_outputs.extend(curr_group_outputs) encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash # Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs)) self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs return encoder_outputs
def gather_mm_embeddings( def gather_mm_embeddings(
...@@ -122,7 +98,7 @@ class EncoderRunner: ...@@ -122,7 +98,7 @@ class EncoderRunner:
# OPTIMIZATION: Skip decode requests. # OPTIMIZATION: Skip decode requests.
continue continue
mm_features = self.req_id_to_mm_features[req_id] mm_features = self.encoder_cache.mm_features[req_id]
for mm_feature in mm_features: for mm_feature in mm_features:
pos_info = mm_feature.mm_position pos_info = mm_feature.mm_position
start_pos = pos_info.offset start_pos = pos_info.offset
...@@ -148,7 +124,7 @@ class EncoderRunner: ...@@ -148,7 +124,7 @@ class EncoderRunner:
continue continue
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None) encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
......
...@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import ( ...@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
get_kv_connector, get_kv_connector,
) )
from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import ModelState from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
...@@ -127,20 +127,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -127,20 +127,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
if self.supports_mm_inputs:
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
dtype=self.dtype,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device) self.output_copy_stream = torch.cuda.Stream(self.device)
...@@ -162,6 +148,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -162,6 +148,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0 self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
self.encoder_cache = None
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
self.speculator = None self.speculator = None
self.num_speculative_steps = 0 self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
...@@ -272,7 +267,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -272,7 +267,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.speculator) prepare_communication_buffer_for_model(self.speculator)
# Initialize the components that require the model. # Initialize the components that require the model.
self.model_state = ModelState(self.vllm_config, self.model, self.device) self.model_state = ModelState(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model: if self.is_pooling_model:
self.pooling_runner = PoolingRunner(self.model) self.pooling_runner = PoolingRunner(self.model)
...@@ -435,12 +432,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -435,12 +432,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc.collect() gc.collect()
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
if self.supports_mm_inputs: if self.encoder_cache is not None:
self.encoder_runner.reset_mm_cache() self.encoder_cache.reset_mm_cache()
def reset_encoder_cache(self) -> None: def reset_encoder_cache(self) -> None:
if self.supports_mm_inputs: if self.encoder_cache is not None:
self.encoder_runner.reset_encoder_cache() self.encoder_cache.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet. # SP is not supported yet.
...@@ -469,14 +466,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -469,14 +466,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config): with self.maybe_setup_dummy_loras(self.lora_config):
inputs_embeds = None
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
model=self.model, model=self.model,
model_state=self.model_state, model_state=self.model_state,
input_buffers=self.input_buffers, input_buffers=self.input_buffers,
inputs_embeds=inputs_embeds,
block_tables=self.block_tables, block_tables=self.block_tables,
attn_groups=self.attn_groups, attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
...@@ -511,15 +504,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -511,15 +504,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
finished_req_ids = finished_req_ids.union(preempted_req_ids) finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids: for req_id in finished_req_ids:
self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
if self.supports_mm_inputs: if self.encoder_cache is not None:
self.encoder_runner.remove_request(req_id) self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id) self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id) self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None: def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs: if self.encoder_cache is not None:
for mm_hash in scheduler_output.free_encoder_mm_hashes: for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash) self.encoder_cache.free_encoder_cache(mm_hash)
def add_requests(self, scheduler_output: SchedulerOutput) -> None: def add_requests(self, scheduler_output: SchedulerOutput) -> None:
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
...@@ -535,8 +528,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -535,8 +528,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
if self.supports_mm_inputs: if self.encoder_cache is not None:
self.encoder_runner.add_request(req_id, new_req_data.mm_features) self.encoder_cache.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data) self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
...@@ -695,7 +688,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -695,7 +688,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dcp_local_seq_lens=dcp_local_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding], input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding], positions=self.input_buffers.positions[:num_tokens_after_padding],
inputs_embeds=None,
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
...@@ -724,26 +716,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -724,26 +716,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return block_tables, slot_mappings return block_tables, slot_mappings
@torch.inference_mode()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
return mm_embeds, is_mm_embed
def sample( def sample(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -890,18 +862,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -890,18 +862,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch.num_scheduled_tokens, input_batch.num_scheduled_tokens,
) )
self._set_active_loras(*lora_inputs) self._set_active_loras(*lora_inputs)
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and self.is_first_pp_rank:
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
input_batch.inputs_embeds = inputs_embeds[
: input_batch.num_tokens_after_padding
]
else: else:
# No actual tokens to run. A dummy run for DP or memory profiling. # No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs) num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
...@@ -934,10 +894,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -934,10 +894,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config, self.kv_cache_config,
) )
inputs_embeds = None
if self.supports_mm_inputs and self.is_first_pp_rank and not dummy_run:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
inputs_embeds = self.model_state.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs,
input_batch,
self.req_states,
)
model_inputs = { model_inputs = {
"input_ids": input_batch.input_ids, "input_ids": input_batch.input_ids,
"positions": input_batch.positions, "positions": input_batch.positions,
"inputs_embeds": input_batch.inputs_embeds, "inputs_embeds": inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default # NOTE: Values returned by `prepare_inputs` will override the default
# values above. # values above.
**self.model_state.prepare_inputs(input_batch, self.req_states), **self.model_state.prepare_inputs(input_batch, self.req_states),
......
...@@ -10,22 +10,43 @@ from vllm.v1.core.sched.output import NewRequestData ...@@ -10,22 +10,43 @@ from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
class ModelState: class ModelState:
def __init__(self, vllm_config: VllmConfig, model: nn.Module, device: torch.device): def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.model = model self.model = model
self.device = device self.device = device
self.supports_mm_inputs = encoder_cache is not None
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dtype = self.model_config.dtype
if self.supports_mm_inputs:
assert encoder_cache is not None
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope: if self.uses_mrope:
...@@ -51,6 +72,29 @@ class ModelState: ...@@ -51,6 +72,29 @@ class ModelState:
if self.uses_mrope: if self.uses_mrope:
self.mrope_state.apply_staged_writes() self.mrope_state.apply_staged_writes()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
req_states.prefill_len.np[input_batch.idx_mapping_np],
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
return inputs_embeds[: input_batch.num_tokens_after_padding]
def prepare_inputs( def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]: ) -> dict[str, torch.Tensor | None]:
...@@ -73,10 +117,14 @@ class ModelState: ...@@ -73,10 +117,14 @@ class ModelState:
def prepare_dummy_inputs( def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]: ) -> dict[str, torch.Tensor | None]:
if not self.uses_mrope: model_inputs = {}
return {} if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
model_inputs["inputs_embeds"] = inputs_embeds
if self.uses_mrope:
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
return {"positions": mrope_positions} model_inputs["positions"] = mrope_positions
return model_inputs
def prepare_attn( def prepare_attn(
self, self,
......
...@@ -44,7 +44,6 @@ class EagleSpeculator: ...@@ -44,7 +44,6 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's # the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B). # hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size() self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment