Unverified commit da543d1a, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Minor refactoring for EncoderRunner (#35628)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent 87d319c5
...@@ -13,12 +13,14 @@ from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs ...@@ -13,12 +13,14 @@ from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner: class EncoderRunner:
def __init__( def __init__(
self, self,
model: SupportsMultiModal,
max_num_tokens: int, max_num_tokens: int,
hidden_size: int, hidden_size: int,
encoder_cache: EncoderCache, encoder_cache: EncoderCache,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
): ):
self.model = model
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.encoder_cache = encoder_cache self.encoder_cache = encoder_cache
...@@ -48,25 +50,17 @@ class EncoderRunner: ...@@ -48,25 +50,17 @@ class EncoderRunner:
@torch.inference_mode() @torch.inference_mode()
def execute_mm_encoder( def execute_mm_encoder(
self, self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]], mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = [] encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False mm_kwargs, device=self.device, pin_memory=False
): ):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items curr_group_outputs, expected_num_items=num_items
) )
encoder_outputs.extend(curr_group_outputs) encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs return encoder_outputs
def gather_mm_embeddings( def gather_mm_embeddings(
...@@ -146,12 +140,11 @@ class EncoderRunner: ...@@ -146,12 +140,11 @@ class EncoderRunner:
@torch.inference_mode() @torch.inference_mode()
def get_inputs_embeds( def get_inputs_embeds(
self, self,
model: SupportsMultiModal,
input_ids: torch.Tensor, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor], mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor, is_mm_embed: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
x = model.embed_input_ids( x = self.model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
) )
# Copy to the pre-allocated buffer for CUDA graphs. # Copy to the pre-allocated buffer for CUDA graphs.
......
...@@ -41,7 +41,9 @@ class DefaultModelState(ModelState): ...@@ -41,7 +41,9 @@ class DefaultModelState(ModelState):
if self.supports_mm_inputs: if self.supports_mm_inputs:
assert encoder_cache is not None assert encoder_cache is not None
self.encoder_cache = encoder_cache
self.encoder_runner = EncoderRunner( self.encoder_runner = EncoderRunner(
model=self.model,
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size, hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache, encoder_cache=encoder_cache,
...@@ -82,7 +84,12 @@ class DefaultModelState(ModelState): ...@@ -82,7 +84,12 @@ class DefaultModelState(ModelState):
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs( mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs scheduled_encoder_inputs
) )
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs) if mm_kwargs:
# Execute the multimodal encoder.
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings( mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids, input_batch.req_ids,
input_batch.num_tokens, input_batch.num_tokens,
...@@ -92,7 +99,7 @@ class DefaultModelState(ModelState): ...@@ -92,7 +99,7 @@ class DefaultModelState(ModelState):
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np], req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
) )
inputs_embeds = self.encoder_runner.get_inputs_embeds( inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed input_batch.input_ids, mm_embeds, is_mm_embed
) )
return inputs_embeds[: input_batch.num_tokens_after_padding] return inputs_embeds[: input_batch.num_tokens_after_padding]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment