"vscode:/vscode.git/clone" did not exist on "de24046fcd24e8faa81de34b17351887bcdfbe51"
Unverified Commit 64172a97 authored by xwjiang2010's avatar xwjiang2010 Committed by GitHub
Browse files

[Feature] Add vision language model support. (#3042)

parent f408d05c
...@@ -29,6 +29,8 @@ _MODELS = { ...@@ -29,6 +29,8 @@ _MODELS = {
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
......
...@@ -250,14 +250,21 @@ class LlamaModel(nn.Module): ...@@ -250,14 +250,21 @@ class LlamaModel(nn.Module):
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up the embeddings of ``input_ids`` in ``self.embed_tokens``.

    Kept separate from ``forward`` so that a caller can fetch the token
    embeddings, modify them, and pass the result back to ``forward``
    through its ``inputs_embeds`` argument.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
......
from typing import List, Optional, Tuple
import torch
from torch import nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
# Pair of tensors used as the element type of ``kv_caches`` in ``forward``.
# NOTE(review): presumably the (key, value) caches -- confirm with the
# attention backend.
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Substring rewrites applied to checkpoint weight names in ``load_weights``
# so that they line up with this module's parameter names (``lm_head`` and
# ``language_model`` are attributes of ``LlavaForConditionalGeneration``).
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two linear layers with an activation in between.

    The first layer maps ``vision_hidden_size`` to ``text_hidden_size``;
    the second keeps ``text_hidden_size``. The activation is resolved by
    name through ``get_act_fn``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()
        # Submodules are created in the order they are applied:
        # linear_1 -> act -> linear_2.
        self.linear_1 = nn.Linear(in_features=vision_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(in_features=text_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)

    def forward(self, image_features):
        """Apply ``linear_1``, the activation, then ``linear_2``."""
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))
def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
inputs_embeds[mask] = vision_embeddings.view(-1,
vision_embeddings.shape[-1])
class LlavaForConditionalGeneration(nn.Module):
    """Llava 1.5 model: optional CLIP vision tower, projector and Llama.

    Image inputs are turned into embeddings by the (optional) vision tower
    and the multi modal projector, merged into the token embeddings at the
    image placeholder positions, and then run through ``LlamaModel``.
    """

    def __init__(self,
                 config: "LlavaConfig",
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional["LinearMethodBase"] = None) -> None:
        """Build the vision tower (if needed), projector and language model.

        Args:
            config: Hugging Face Llava config; ``vision_config``,
                ``text_config`` and ``projector_hidden_act`` are read here.
            vision_language_config: vLLM vision settings. Must be provided.
            linear_method: Forwarded to ``LlamaModel``.
        """
        super().__init__()
        self.config = config

        self.vision_language_config = vision_language_config
        assert self.vision_language_config, (
            "Provide `image_input_type` and other vision "
            "related configurations through LLM entrypoint "
            "or engine arguments.")

        # The vision tower is only built when raw pixel values are supplied;
        # for any other input type the input is used as features directly.
        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.linear_method = linear_method
        self.language_model = LlamaModel(config.text_config, linear_method)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _check_image_input_shape(self, image_input: torch.Tensor) -> None:
        """Raise ``ValueError`` if the non-batch dims differ from the config."""
        expected_dims = list(self.vision_language_config.image_input_shape[1:])
        if list(image_input.shape[1:]) != expected_dims:
            raise ValueError(
                f"The expected image tensor shape is batch dimension "
                f"plus "
                f"{self.vision_language_config.image_input_shape[1:]}."
                f" You supplied {image_input.shape}. "
                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

    def _select_image_features(self,
                               image_input: torch.Tensor) -> torch.Tensor:
        """Turn ``image_input`` into the features fed to the projector."""
        if self.vision_tower is None:
            # No vision tower was built, so the input already holds the
            # image features.
            return image_input

        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
        image_outputs = self.vision_tower(image_input,
                                          output_hidden_states=True)
        image_features = image_outputs.hidden_states[
            self.config.vision_feature_layer]

        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        strategy = self.config.vision_feature_select_strategy
        if strategy == "default":
            # Drop the first position along dim 1.
            # NOTE(review): presumably the CLS token -- confirm.
            return image_features[:, 1:]
        if strategy == "full":
            return image_features
        raise ValueError(f"Unexpected select feature strategy: "
                         f"{strategy}")

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[KVCache],
            attn_metadata: AttentionMetadata,
            image_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:  # noqa: E501
        """Run forward pass for Llava 1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to the huggingface implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES (the features before they go through the multi modal
        projector):
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].

        Returns:
            The hidden states produced by the language model.

        Raises:
            ValueError: If `image_input` does not match the configured
                `image_input_shape`, or the feature select strategy is
                unknown.
        """
        if image_input is None:
            inputs_embeds = None
        else:
            self._check_image_input_shape(image_input)
            image_features = self._select_image_features(image_input)
            vision_embeddings = self.multi_modal_projector(image_features)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            _merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)
            # The language model must use the merged embeddings rather than
            # embedding the token ids again.
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` using the LM head weights."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this module's parameters.

        Weight names are rewritten with ``_KEYS_TO_MODIFY_MAPPING``. Vision
        weights are loaded with the default loader when a vision tower
        exists and skipped otherwise. Other weights matching an entry of
        the stacked mapping are loaded into the fused parameter through its
        ``weight_loader``; everything else uses the parameter's own
        ``weight_loader`` or ``default_weight_loader``.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping matched this weight.
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
...@@ -303,6 +303,25 @@ class SequenceGroupState: ...@@ -303,6 +303,25 @@ class SequenceGroupState:
generator: Optional = None generator: Optional = None
class MultiModalData:
    """Non-text payload attached to a request.

    Args:
        type: Which kind of data this is (see ``Type``).
        data: The payload itself. Its required shape and meaning depend on
            the vision language config of the hosted model; see
            `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        # The only kind of multi modal data defined here.
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type, self.data = type, data
class SequenceGroup: class SequenceGroup:
"""A group of sequences that are generated from the same prompt. """A group of sequences that are generated from the same prompt.
...@@ -312,6 +331,7 @@ class SequenceGroup: ...@@ -312,6 +331,7 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
""" """
def __init__( def __init__(
...@@ -321,6 +341,7 @@ class SequenceGroup: ...@@ -321,6 +341,7 @@ class SequenceGroup:
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
...@@ -333,6 +354,7 @@ class SequenceGroup: ...@@ -333,6 +354,7 @@ class SequenceGroup:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
@property @property
def prompt(self) -> str: def prompt(self) -> str:
...@@ -450,6 +472,7 @@ class SequenceGroupMetadata: ...@@ -450,6 +472,7 @@ class SequenceGroupMetadata:
numbers) numbers)
state: Internal state tied to this sequence group. state: Internal state tied to this sequence group.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data.
""" """
def __init__( def __init__(
...@@ -462,6 +485,7 @@ class SequenceGroupMetadata: ...@@ -462,6 +485,7 @@ class SequenceGroupMetadata:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None, computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None, state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.is_prompt = is_prompt self.is_prompt = is_prompt
...@@ -470,6 +494,7 @@ class SequenceGroupMetadata: ...@@ -470,6 +494,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables self.block_tables = block_tables
self.lora_request = lora_request self.lora_request = lora_request
self.computed_block_nums = computed_block_nums self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state self.state = SequenceGroupState() if state is None else state
@property @property
......
...@@ -40,3 +40,17 @@ def get_config(model: str, ...@@ -40,3 +40,17 @@ def get_config(model: str,
revision=revision, revision=revision,
code_revision=code_revision) code_revision=code_revision)
return config return config
def get_hf_text_config(config: PretrainedConfig):
    """Return the config that describes the language model.

    For multi modal configs this is the nested ``text_config``; a config
    without a ``text_config`` attribute is returned unchanged.
    """
    if not hasattr(config, "text_config"):
        return config

    text_config = config.text_config
    # Downstream code assumes the text config exposes
    # `num_attention_heads` (among others). Fail early here if the
    # transformers config does not meet that assumption.
    assert hasattr(text_config, "num_attention_heads")
    return text_config
...@@ -377,6 +377,16 @@ class CudaMemoryProfiler: ...@@ -377,6 +377,16 @@ class CudaMemoryProfiler:
gc.collect() gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Convert a comma-separated string of integers into a tuple of ints.

    Whitespace around each item is accepted because ``int`` ignores it.

    Args:
        s: Text such as ``"1,3,336,336"``.

    Returns:
        A tuple with one integer per comma-separated item.

    Raises:
        ValueError: If any item is not a valid integer (this includes empty
            items, e.g. from an empty string or a trailing comma).
    """
    try:
        return tuple(int(item) for item in s.split(","))
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]: def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len assert len(x) <= max_len
return x + [pad] * (max_len - len(x)) return x + [pad] * (max_len - len(x))
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -21,7 +21,8 @@ from vllm.model_executor.parallel_utils.communication_op import ( ...@@ -21,7 +21,8 @@ from vllm.model_executor.parallel_utils.communication_op import (
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce) with_cupy_nccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad, is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim) maybe_expand_dim)
...@@ -49,6 +50,7 @@ class ModelRunner: ...@@ -49,6 +50,7 @@ class ModelRunner:
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -83,17 +85,20 @@ class ModelRunner: ...@@ -83,17 +85,20 @@ class ModelRunner:
self.graph_block_tables = None # Set after initial profiling. self.graph_block_tables = None # Set after initial profiling.
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.vision_language_config = vision_language_config
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None) self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model(self.model_config, self.model = get_model(
self.device_config, self.model_config,
lora_config=self.lora_config, self.device_config,
parallel_config=self.parallel_config, lora_config=self.lora_config,
scheduler_config=self.scheduler_config) vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took " logger.info(f"Loading model weights took "
...@@ -130,7 +135,8 @@ class ModelRunner: ...@@ -130,7 +135,8 @@ class ModelRunner:
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], List[int], Set[LoRARequest]]: List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -143,6 +149,7 @@ class ModelRunner: ...@@ -143,6 +149,7 @@ class ModelRunner:
context_lens: List[int] = [] context_lens: List[int] = []
subquery_lens: List[int] = [] subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -188,6 +195,10 @@ class ModelRunner: ...@@ -188,6 +195,10 @@ class ModelRunner:
(prompt_len - computed_len (prompt_len - computed_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
...@@ -236,6 +247,16 @@ class ModelRunner: ...@@ -236,6 +247,16 @@ class ModelRunner:
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
# Prepare prefix block tables # Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
...@@ -291,7 +312,7 @@ class ModelRunner: ...@@ -291,7 +312,7 @@ class ModelRunner:
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens, return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) lora_requests, multi_modal_input)
def _prepare_decode( def _prepare_decode(
self, self,
...@@ -525,7 +546,7 @@ class ModelRunner: ...@@ -525,7 +546,7 @@ class ModelRunner:
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping]: Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -534,13 +555,15 @@ class ModelRunner: ...@@ -534,13 +555,15 @@ class ModelRunner:
if is_prompt: if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens, (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list) lora_requests, multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, attn_metadata, (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list) lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] prompt_lens = []
subquery_lens = None subquery_lens = None
multi_modal_input = None
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens) subquery_lens)
...@@ -561,6 +584,7 @@ class ModelRunner: ...@@ -561,6 +584,7 @@ class ModelRunner:
sampling_metadata.selected_token_indices, sampling_metadata.selected_token_indices,
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
} }
metadata_dict.update(attn_metadata.asdict_zerocopy()) metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
...@@ -572,6 +596,7 @@ class ModelRunner: ...@@ -572,6 +596,7 @@ class ModelRunner:
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict) attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
...@@ -584,7 +609,8 @@ class ModelRunner: ...@@ -584,7 +609,8 @@ class ModelRunner:
) )
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping) sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
...@@ -593,8 +619,8 @@ class ModelRunner: ...@@ -593,8 +619,8 @@ class ModelRunner:
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata, (input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_requests, lora_mapping, multi_modal_input
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) ) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config: if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping) self.set_active_loras(lora_requests, lora_mapping)
...@@ -605,12 +631,15 @@ class ModelRunner: ...@@ -605,12 +631,15 @@ class ModelRunner:
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
hidden_states = model_executable( execute_model_kwargs = {
input_ids=input_tokens, "input_ids": input_tokens,
positions=input_positions, "positions": input_positions,
kv_caches=kv_caches, "kv_caches": kv_caches,
attn_metadata=attn_metadata, "attn_metadata": attn_metadata,
) }
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
...@@ -658,10 +687,22 @@ class ModelRunner: ...@@ -658,10 +687,22 @@ class ModelRunner:
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for vision encoding, which needs
# to be accounted for when calculating the GPU blocks for
# vLLM block manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
if self.vision_language_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens /
self.vision_language_config.image_feature_size))
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len) seq_data, fake_multi_modal_input = _prepare_fake_inputs(
seq_len, self.vision_language_config)
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
...@@ -670,6 +711,7 @@ class ModelRunner: ...@@ -670,6 +711,7 @@ class ModelRunner:
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id] lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None, if dummy_lora_requests_per_seq else None,
multi_modal_data=fake_multi_modal_input,
) )
seqs.append(seq) seqs.append(seq)
...@@ -831,6 +873,7 @@ class CUDAGraphRunner: ...@@ -831,6 +873,7 @@ class CUDAGraphRunner:
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
memory_pool, memory_pool,
**kwargs,
) -> None: ) -> None:
assert self.graph is None assert self.graph is None
# Run the model once without capturing the graph. # Run the model once without capturing the graph.
...@@ -842,6 +885,7 @@ class CUDAGraphRunner: ...@@ -842,6 +885,7 @@ class CUDAGraphRunner:
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
**kwargs,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -856,6 +900,7 @@ class CUDAGraphRunner: ...@@ -856,6 +900,7 @@ class CUDAGraphRunner:
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
**kwargs,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -877,6 +922,7 @@ class CUDAGraphRunner: ...@@ -877,6 +922,7 @@ class CUDAGraphRunner:
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. # KV caches are fixed tensors, so we don't need to copy them.
del kv_caches del kv_caches
...@@ -922,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -922,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int:
else: else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build placeholder inputs for the profile run.

    Returns a ``(SequenceData, multi modal data)`` pair. Without a vision
    language config the prompt is ``seq_len`` zeros and the second item is
    ``None``. With one, the prompt starts with ``image_feature_size`` image
    tokens followed by zeros, and the second item is an all-zero float16
    image of shape ``image_input_shape``.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    feature_size = vision_language_config.image_feature_size
    image_placeholders = [vision_language_config.image_token_id
                          ] * feature_size
    # The rest of the prompt is filled with token id 0.
    text_padding = [0] * (seq_len - feature_size)
    dummy_image = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_placeholders + text_padding), dummy_image
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import cupy_utils
...@@ -39,6 +39,7 @@ class Worker: ...@@ -39,6 +39,7 @@ class Worker:
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
...@@ -54,13 +55,20 @@ class Worker: ...@@ -54,13 +55,20 @@ class Worker:
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, self.vision_language_config = vision_language_config
parallel_config, if self.vision_language_config:
scheduler_config, assert not self.lora_config, (
device_config, "To be tested: vision language model with LoRA settings.")
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype, self.model_runner = ModelRunner(
is_driver_worker=is_driver_worker) model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment