Unverified Commit e09ce759 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[TPU] Remove multi-modal args in TPU backend (#6504)

parent 5fa6e987
import time import time
from typing import List, Mapping, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ...@@ -12,8 +12,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata, SamplerOutput, SequenceGroupMetadata,
SequenceOutput) SequenceOutput)
...@@ -68,10 +66,6 @@ class TPUModelRunner: ...@@ -68,10 +66,6 @@ class TPUModelRunner:
False, False,
) )
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
def load_model(self) -> None: def load_model(self) -> None:
self.device = self.device_config.device self.device = self.device_config.device
...@@ -154,7 +148,7 @@ class TPUModelRunner: ...@@ -154,7 +148,7 @@ class TPUModelRunner:
# Dummy run. # Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata, self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, None, t, p, num_samples) input_lens, t, p, num_samples)
def warmup_model( def warmup_model(
self, self,
...@@ -199,14 +193,12 @@ class TPUModelRunner: ...@@ -199,14 +193,12 @@ class TPUModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
prompt_lens: List[int] = [] prompt_lens: List[int] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
...@@ -232,11 +224,6 @@ class TPUModelRunner: ...@@ -232,11 +224,6 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot) slot_mapping[-1].append(slot)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
assert len(prompt_lens) > 0 assert len(prompt_lens) > 0
num_prefills = len(prompt_lens) num_prefills = len(prompt_lens)
num_prefill_tokens = sum(prompt_lens) num_prefill_tokens = sum(prompt_lens)
...@@ -274,24 +261,17 @@ class TPUModelRunner: ...@@ -274,24 +261,17 @@ class TPUModelRunner:
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
return input_tokens, input_positions, attn_metadata, prompt_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_kwargs)
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
context_lens: List[int] = [] context_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
batch_idx = 0 batch_idx = 0
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -317,11 +297,6 @@ class TPUModelRunner: ...@@ -317,11 +297,6 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append([slot]) slot_mapping.append([slot])
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
batch_size = _get_padded_batch_size(batch_idx) batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings input_tokens = input_tokens + [[0]] * num_paddings
...@@ -355,12 +330,7 @@ class TPUModelRunner: ...@@ -355,12 +330,7 @@ class TPUModelRunner:
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
return input_tokens, input_positions, attn_metadata, input_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, input_lens,
multi_modal_kwargs)
def _prepare_sample( def _prepare_sample(
self, self,
...@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module): ...@@ -513,7 +483,6 @@ class ModelWrapper(nn.Module):
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
input_lens: torch.Tensor, input_lens: torch.Tensor,
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
t: torch.Tensor, t: torch.Tensor,
p: torch.Tensor, p: torch.Tensor,
num_samples: int, num_samples: int,
...@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module): ...@@ -527,8 +496,6 @@ class ModelWrapper(nn.Module):
memory profiling at initialization. memory profiling at initialization.
attn_metadata: The Pallas attention metadata. attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size]. input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size]. t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size]. p: The top-p probability of shape [batch_size].
""" """
...@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module): ...@@ -573,7 +540,6 @@ class ModelWrapper(nn.Module):
position_ids, position_ids,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
**(multi_modal_kwargs or {}),
) )
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment