Commit 3f674a49 (unverified) — authored by Cyrus Leung, committed by GitHub
Browse files

[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)

parent 70b746ef
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
TypedDict) Tuple, TypedDict)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext): ...@@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
def dummy_seq_data_for_chameleon( def dummy_seq_data_for_chameleon(
seq_len: int, seq_len: int,
num_images: int,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
...@@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon( ...@@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size) token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
def dummy_image_for_chameleon( def dummy_image_for_chameleon(
num_images: int,
*,
image_width_override: Optional[int] = None, image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None, image_height_override: Optional[int] = None,
): ):
...@@ -87,17 +89,20 @@ def dummy_image_for_chameleon( ...@@ -87,17 +89,20 @@ def dummy_image_for_chameleon(
height = image_height_override height = image_height_override
image = Image.new("RGB", (width, height), color=0) image = Image.new("RGB", (width, height), color=0)
return {"image": image} return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    """Produce the dummy inputs used to profile Chameleon.

    The number of images is read from ``mm_counts["image"]`` and forwarded
    to both helpers, so the token sequence and the image payload are built
    from the same count. ``ctx`` is accepted but not used by this function.

    Returns a ``(seq_data, mm_data)`` tuple.
    """
    image_count = mm_counts["image"]

    # Tuple elements are evaluated left to right: sequence data first,
    # then the multi-modal data.
    return (
        dummy_seq_data_for_chameleon(
            seq_len,
            image_count,
            image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        ),
        dummy_image_for_chameleon(image_count),
    )
......
...@@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: ...@@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def dummy_seq_data_for_clip( def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig, hf_config: CLIPVisionConfig,
seq_len: int, seq_len: int,
num_images: int,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
...@@ -52,13 +53,14 @@ def dummy_seq_data_for_clip( ...@@ -52,13 +53,14 @@ def dummy_seq_data_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size) token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
def dummy_image_for_clip( def dummy_image_for_clip(
hf_config: CLIPVisionConfig, hf_config: CLIPVisionConfig,
num_images: int,
*, *,
image_width_override: Optional[int] = None, image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None, image_height_override: Optional[int] = None,
...@@ -70,7 +72,7 @@ def dummy_image_for_clip( ...@@ -70,7 +72,7 @@ def dummy_image_for_clip(
height = image_height_override height = image_height_override
image = Image.new("RGB", (width, height), color=0) image = Image.new("RGB", (width, height), color=0)
return {"image": image} return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_clip( def input_processor_for_clip(
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext): ...@@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
return (ncol + 1) * nrow return (ncol + 1) * nrow
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    """Build dummy sequence data holding ``num_images`` Fuyu image layouts.

    One image is laid out as ``nrow`` rows, each made of ``ncol`` image
    tokens followed by a newline token. That layout is repeated once per
    image, and the remainder of the sequence is filled with token id ``0``.
    """
    ncol, nrow = get_max_fuyu_image_feature_size()
    tokens_per_image = get_max_fuyu_image_tokens(ctx)

    one_row = [_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]
    one_image = one_row * nrow

    # A negative pad length yields an empty list, so no padding is added
    # when the image tokens already exceed ``seq_len``.
    padding = [0] * (seq_len - tokens_per_image * num_images)
    return SequenceData(one_image * num_images + padding)
def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    """Create the dummy multi-modal data dict for Fuyu.

    A single RGB image of the requested size, filled with colour ``0``, is
    created. For ``num_images == 1`` the image itself is stored under the
    ``"image"`` key; otherwise the key holds a list that repeats that same
    image object ``num_images`` times.
    """
    blank = Image.new("RGB", (image_width, image_height), color=0)

    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    """Produce the dummy inputs used to profile Fuyu.

    The image count comes from ``mm_counts["image"]``. The dummy images are
    created at ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` by
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT``.

    Returns a ``(seq_data, mm_data)`` tuple.
    """
    image_count = mm_counts["image"]

    # Sequence data is built before the images, as tuple elements are
    # evaluated left to right.
    return (
        dummy_seq_data_for_fuyu(ctx, seq_len, image_count),
        dummy_image_for_fuyu(
            image_count,
            image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        ),
    )
......
...@@ -11,14 +11,11 @@ logger = init_logger(__name__) ...@@ -11,14 +11,11 @@ logger = init_logger(__name__)
@runtime_checkable @runtime_checkable
class SupportsMultiModal(Protocol): class SupportsMultiModal(Protocol):
""" """The interface required for all multi-modal models."""
The interface required for all multimodal (vision or audio) language
models.
"""
supports_multimodal: ClassVar[Literal[True]] = True supports_multimodal: ClassVar[Literal[True]] = True
""" """
A flag that indicates this model supports multimodal inputs. A flag that indicates this model supports multi-modal inputs.
Note: Note:
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import itertools import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def input_mapper_for_internvl(ctx: InputContext, data: object): def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch min_num = hf_config.min_dynamic_patch
...@@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
}) })
def dummy_data_for_internvl(ctx: InputContext, seq_len: int): def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_internvl_image_tokens(ctx) image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config model_config = ctx.model_config
...@@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): ...@@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT, image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0], add_special_tokens=False)[0],
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
...@@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): ...@@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
vision_config, vision_config,
num_images,
image_width_override=max_image_width, image_width_override=max_image_width,
image_height_override=max_image_height, image_height_override=max_image_height,
) )
......
import itertools import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata ...@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext): ...@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}") raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int): def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx) image_feature_size = get_max_llava_image_tokens(ctx)
...@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int): ...@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_clip(vision_config) mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig): elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip( seq_data = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_siglip(vision_config) mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
......
import itertools import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata ...@@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext): ...@@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
) )
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextConfig) hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_next_image_tokens(ctx) image_feature_size = get_max_llava_next_image_tokens(ctx)
...@@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): ...@@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
vision_config, vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
...@@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): ...@@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_siglip( seq_data = dummy_seq_data_for_siglip(
vision_config, vision_config,
seq_len, seq_len,
num_images,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_siglip( mm_data = dummy_image_for_siglip(
vision_config, vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
import math import math
import re import re
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict, from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
Union) TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
...@@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): ...@@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
return getattr(hf_config, "query_num", 64) return getattr(hf_config, "query_num", 64)
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
    """Build dummy sequence data of ``seq_len`` tokens, all with id ``0``.

    ``num_images`` is accepted for signature parity with the other dummy
    builders but does not influence the generated tokens.
    """
    return SequenceData([0] * seq_len)
def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
    """Create the dummy multi-modal data dict for MiniCPM-V.

    The image is a square RGB image whose side is ``hf_config.image_size``,
    filled with colour ``0``. A single image is stored directly under
    ``"image"``; for any other count the key holds a list repeating that
    same image object ``num_images`` times.
    """
    side = hf_config.image_size
    blank = Image.new("RGB", (side, side), color=0)

    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    """Produce the dummy inputs used to profile MiniCPM-V.

    The HF config is fetched from ``ctx`` and the image count is read from
    ``mm_counts["image"]``; both feed the dummy image builder.

    Returns a ``(seq_data, mm_data)`` tuple.
    """
    # Same lookup order as before: config first, then the image count.
    config = ctx.get_hf_config()
    image_count = mm_counts["image"]

    return (
        dummy_seq_data_for_minicpmv(seq_len, image_count),
        dummy_image_for_minicpmv(config, image_count),
    )
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.models.gemma import GemmaModel
...@@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext): ...@@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
return get_max_siglip_image_tokens(vision_config) return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    """Produce the dummy inputs used to profile PaliGemma.

    The vision config is taken from the PaliGemma HF config and handed to
    the SigLIP dummy builders together with the image count from
    ``mm_counts["image"]``. No feature-size or image-size override is
    passed to those builders.

    Returns a ``(seq_data, mm_data)`` tuple.
    """
    paligemma_config = ctx.get_hf_config(PaliGemmaConfig)
    siglip_config = paligemma_config.vision_config
    image_count = mm_counts["image"]

    # Sequence data is built before the images (left-to-right evaluation).
    return (
        dummy_seq_data_for_siglip(
            siglip_config,
            seq_len,
            image_count,
            image_token_id=paligemma_config.image_token_index,
        ),
        dummy_image_for_siglip(siglip_config, image_count),
    )
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
import re import re
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig ...@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext): ...@@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext):
) )
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_phi3v_image_tokens(ctx) image_feature_size = get_max_phi3v_image_tokens(ctx)
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len, seq_len,
num_images,
image_token_id=_IMAGE_TOKEN_ID, image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
......
...@@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: ...@@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
def dummy_seq_data_for_siglip( def dummy_seq_data_for_siglip(
hf_config: SiglipVisionConfig, hf_config: SiglipVisionConfig,
seq_len: int, seq_len: int,
num_images: int,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
...@@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip( ...@@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size) token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids) return SequenceData(token_ids)
def dummy_image_for_siglip( def dummy_image_for_siglip(
hf_config: SiglipVisionConfig, hf_config: SiglipVisionConfig,
num_images: int,
*, *,
image_width_override: Optional[int] = None, image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None, image_height_override: Optional[int] = None,
...@@ -79,7 +81,7 @@ def dummy_image_for_siglip( ...@@ -79,7 +81,7 @@ def dummy_image_for_siglip(
height = image_height_override height = image_height_override
image = Image.new("RGB", (width, height), color=0) image = Image.new("RGB", (width, height), color=0)
return {"image": image} return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_siglip( def input_processor_for_siglip(
......
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional from typing import Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
import numpy as np import numpy as np
import torch import torch
...@@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase):
batched_inputs) batched_inputs)
_T = TypeVar("_T")
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.
The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM.

    Declared with ``total=False``, so every key may be omitted. Each value
    is typed as ``MultiModalData``: either one item or a list of items.
    """

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] MultiModalDataDict = Union[MultiModalDataBuiltins,
Mapping[str, MultiModalData[object]]]
""" """
A dictionary containing an item for each modality type to input. A dictionary containing an item for each modality type to input.
...@@ -137,7 +150,8 @@ Note: ...@@ -137,7 +150,8 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`. Read more on that :ref:`here <adding_multimodal_plugin>`.
""" """
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs] MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs]
""" """
Return a dictionary to be passed as keyword arguments to Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
...@@ -181,8 +195,11 @@ class MultiModalPlugin(ABC): ...@@ -181,8 +195,11 @@ class MultiModalPlugin(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _default_input_mapper(self, ctx: InputContext, def _default_input_mapper(
data: object) -> MultiModalInputs: self,
ctx: InputContext,
data: MultiModalData[object],
) -> MultiModalInputs:
""" """
Return a dictionary to be passed as keyword arguments to Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to :meth:`~torch.nn.Module.forward`. This is similar in concept to
...@@ -225,7 +242,7 @@ class MultiModalPlugin(ABC): ...@@ -225,7 +242,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def map_input(self, model_config: ModelConfig, def map_input(self, model_config: ModelConfig,
data: object) -> MultiModalInputs: data: MultiModalData[object]) -> MultiModalInputs:
""" """
Transform the data into a dictionary of model inputs using the Transform the data into a dictionary of model inputs using the
input mapper registered for that model. input mapper registered for that model.
...@@ -254,8 +271,8 @@ class MultiModalPlugin(ABC): ...@@ -254,8 +271,8 @@ class MultiModalPlugin(ABC):
@abstractmethod @abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
""" """
Calculate the maximum number of multimodal tokens input to the language Calculate the maximum number of tokens, corresponding to a single
model. This does not include tokens that correspond to the input text. instance of multimodal data, that are passed to the language model.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -269,8 +286,9 @@ class MultiModalPlugin(ABC): ...@@ -269,8 +286,9 @@ class MultiModalPlugin(ABC):
max_mm_tokens: Optional[MultiModalTokensCalc] = None, max_mm_tokens: Optional[MultiModalTokensCalc] = None,
): ):
""" """
Register the maximum number of multi-modal tokens input to the Register the maximum number of tokens, corresponding to a single
language model for a model class. instance of multimodal data, that are passed to the language model
for a model class.
If `None` is provided, then the default calculation is used instead. If `None` is provided, then the default calculation is used instead.
......
...@@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor ...@@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalInputs, MultiModalPlugin from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin): ...@@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin):
model_config.model, model_config.model,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code)
def _default_input_mapper(self, ctx: InputContext, def _default_input_mapper(
data: object) -> MultiModalInputs: self,
ctx: InputContext,
data: MultiModalData[object],
) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
# PIL image # PIL image
......
import functools import functools
from typing import Dict, Optional, Sequence from collections import UserDict
from typing import Dict, Mapping, Optional, Sequence
import torch from vllm.config import ModelConfig, MultiModalConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .audio import AudioPlugin from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin, MultiModalTokensCalc) MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
from .image import ImagePlugin from .image import ImagePlugin
logger = init_logger(__name__) logger = init_logger(__name__)
class _MultiModalLimits(UserDict):
    """
    Mapping used for `_limits_by_model` that turns a failed lookup into a
    more informative error.

    Values are dictionaries of integer limits keyed by string; a lookup
    for a model that has no entry raises a `KeyError` whose message names
    the model and hints at the missing initialization call.
    """

    def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
        try:
            limits = super().__getitem__(key)
        except KeyError as err:
            # Keep the original lookup failure attached as the cause so
            # the traceback still shows where the miss happened.
            raise KeyError(
                f"Cannot find `mm_limits` for model={key.model}. Did you "
                "forget to call `init_mm_limits_per_prompt`?") from err

        return limits
class MultiModalRegistry: class MultiModalRegistry:
""" """
A registry that dispatches data processing to the A registry that dispatches data processing to the
...@@ -28,6 +42,11 @@ class MultiModalRegistry: ...@@ -28,6 +42,11 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins} self._plugins = {p.get_data_key(): p for p in plugins}
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
self._limits_by_model = _MultiModalLimits()
def register_plugin(self, plugin: MultiModalPlugin) -> None: def register_plugin(self, plugin: MultiModalPlugin) -> None:
""" """
Register a multi-modal plugin so it can be recognized by vLLM. Register a multi-modal plugin so it can be recognized by vLLM.
...@@ -86,13 +105,24 @@ class MultiModalRegistry: ...@@ -86,13 +105,24 @@ class MultiModalRegistry:
via the input mapper registered for that model. via the input mapper registered for that model.
See :meth:`MultiModalPlugin.map_input` for more details. See :meth:`MultiModalPlugin.map_input` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
merged_dict: Dict[str, torch.Tensor] = {} merged_dict: Dict[str, NestedTensors] = {}
for data_key, data_value in data.items(): for data_key, data_value in data.items():
input_dict = self._get_plugin(data_key) \ plugin = self._get_plugin(data_key)
.map_input(model_config, data_value)
num_items = len(data_value) if isinstance(data_value, list) else 1
max_items = self._limits_by_model[model_config][data_key]
if num_items > max_items:
raise ValueError(
f"You set {data_key}={max_items} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but found {num_items} items "
"in the same prompt.")
input_dict = plugin.map_input(model_config, data_value)
for input_key, input_tensor in input_dict.items(): for input_key, input_tensor in input_dict.items():
if input_key in merged_dict: if input_key in merged_dict:
raise ValueError(f"The input mappers (keys={set(data)}) " raise ValueError(f"The input mappers (keys={set(data)}) "
...@@ -115,8 +145,9 @@ class MultiModalRegistry: ...@@ -115,8 +145,9 @@ class MultiModalRegistry:
max_mm_tokens: Optional[MultiModalTokensCalc] = None, max_mm_tokens: Optional[MultiModalTokensCalc] = None,
): ):
""" """
Register the maximum number of tokens, belonging to a Register the maximum number of tokens, corresponding to a single
specific modality, input to the language model for a model class. instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
""" """
return self._get_plugin(data_type_key) \ return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens) .register_max_multimodal_tokens(max_mm_tokens)
...@@ -126,8 +157,8 @@ class MultiModalRegistry: ...@@ -126,8 +157,8 @@ class MultiModalRegistry:
max_mm_tokens: Optional[MultiModalTokensCalc] = None, max_mm_tokens: Optional[MultiModalTokensCalc] = None,
): ):
""" """
Register the maximum number of image tokens Register the maximum number of image tokens, corresponding to a single
input to the language model for a model class. image, that are passed to the language model for a model class.
""" """
return self.register_max_multimodal_tokens("image", max_mm_tokens) return self.register_max_multimodal_tokens("image", max_mm_tokens)
...@@ -137,7 +168,61 @@ class MultiModalRegistry: ...@@ -137,7 +168,61 @@ class MultiModalRegistry:
for profiling the memory usage of a model. for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin = self._limits_by_model[model_config]
return sum((limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items())
def init_mm_limits_per_prompt(
    self,
    model_config: ModelConfig,
    multimodal_config: Optional[MultiModalConfig],
) -> None:
    """
    Initialize the maximum number of multi-modal input instances for each
    modality that are allowed per prompt for a model class.

    Args:
        model_config: The model whose limits are being set. It is used as
            the key under which the limits are stored in
            ``self._limits_by_model``.
        multimodal_config: Provides the user-specified limits through its
            ``limit_per_prompt`` mapping. If `None`, every registered
            plugin gets a limit of 0.

    Note:
        Calling this again for the same ``model_config`` logs a warning
        and overwrites the previously stored limits.
    """
    if model_config in self._limits_by_model:
        logger.warning(
            "`mm_limits` has already been set for model=%s, and will "
            "be overwritten by the new values.", model_config.model)

    if multimodal_config is None:
        # This is used for non-multimodal models.
        # Build a fresh all-zero mapping from the plugins known right now
        # instead of storing one shared dict: the result then covers
        # every key in `self._plugins` (which later lookups index by),
        # and no two models alias the same mutable object.
        limits_per_plugin = {key: 0 for key in self._plugins}
    else:
        config_limits_per_plugin = multimodal_config.limit_per_prompt

        extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
        if extra_keys:
            logger.warning(
                "Detected extra keys in `--limit-mm-per-prompt` which "
                "are not registered as multi-modal plugins: %s. "
                "They will be ignored.", extra_keys)

        # NOTE: Currently the default is set to 1 for each plugin
        # TODO: Automatically determine the limits based on budget
        # once more models support multi-image inputs
        limits_per_plugin = {
            key: config_limits_per_plugin.get(key, 1)
            for key in self._plugins
        }

    self._limits_by_model[model_config] = limits_per_plugin
def get_mm_limits_per_prompt(
self,
model_config: ModelConfig,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
return sum( return self._limits_by_model[model_config]
plugin.get_max_multimodal_tokens(model_config)
for plugin in self._plugins.values())
...@@ -13,7 +13,6 @@ import threading ...@@ -13,7 +13,6 @@ import threading
import uuid import uuid
import warnings import warnings
from asyncio import FIRST_COMPLETED, ensure_future from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
...@@ -760,16 +759,6 @@ class CudaMemoryProfiler: ...@@ -760,16 +759,6 @@ class CudaMemoryProfiler:
gc.collect() gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated string of integers into a tuple.

    Raises:
        ValueError: If any comma-separated piece cannot be converted
            with `int`.
    """
    pieces = s.split(",")
    try:
        numbers = tuple(int(piece) for piece in pieces)
    except ValueError as err:
        message = ("String must be a series of integers separated by commas "
                   f"(e.g., 1, 2, 3). Given input: {s}")
        raise ValueError(message) from err

    return numbers
def make_ndarray_with_pad( def make_ndarray_with_pad(
x: List[List[T]], x: List[List[T]],
pad: T, pad: T,
...@@ -863,23 +852,6 @@ def is_list_of( ...@@ -863,23 +852,6 @@ def is_list_of(
assert_never(check) assert_never(check)
def merge_dicts(dict1: Dict[K, List[T]],
                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
    """Merge two dicts whose values are lists into a new plain dict.

    For a key present in both inputs, the resulting list holds the items
    from dict1 followed by the items from dict2. Neither input dict nor
    any of its lists is modified.
    """
    combined: Dict[K, List[T]] = {}

    # Process dict1 first so its keys and items come first in the result.
    for source in (dict1, dict2):
        for key, items in source.items():
            combined.setdefault(key, []).extend(items)

    return combined
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T] Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable.""" """A nested JSON structure where the leaves need not be JSON-serializable."""
......
...@@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, ...@@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig, ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
''' '''
EncoderDecoderModelRunner constructor. EncoderDecoderModelRunner constructor.
...@@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet")
batch_size = 0 batch_size = 0
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
...@@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, _ = INPUT_REGISTRY \ seq_data, _ = input_registry \
.dummy_data_for_profiling(model_config, seq_len) .dummy_data_for_profiling(model_config, seq_len, mm_registry)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, ( assert len(seq_data.prompt_token_ids) >= seq_len, (
......
...@@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig) ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora, ...@@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_multimodal) supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
...@@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
multimodal_config: Optional[MultiModalConfig] = None, multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None, observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) if num_attn_heads else None ) if num_attn_heads else None
# Multi-modal data support # Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ self.input_registry = input_registry
.create_input_mapper(self.model_config) self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
# Lazy initialization # Lazy initialization
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
...@@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
assert supports_lora(self.model), "Model does not support LoRA" assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_multimodal( assert not supports_multimodal(
self.model self.model
), "To be tested: multimodal language model with LoRA settings." ), "To be tested: Multi-modal model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
...@@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for vision encoding, which needs # Additional GPU memory may be needed for multi-modal encoding, which
# to be accounted for when calculating the GPU blocks for # needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager. # vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption, # To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config model_config = self.model_config
mm_config = self.multimodal_config
if supports_multimodal(self.model): input_registry = self.input_registry
max_mm_tokens = MULTIMODAL_REGISTRY \ mm_registry = self.mm_registry
.get_max_multimodal_tokens(model_config) mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs, max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens) max_num_batched_tokens // max_mm_tokens)
...@@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len) .dummy_data_for_profiling(model_config, seq_len, mm_registry)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
......
...@@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
*args, *args,
**kwargs, **kwargs,
): ):
...@@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
) )
# Multi-modal data support # Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ self.input_registry = input_registry
.create_input_mapper(self.model_config) self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
...@@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for vision encoding, which needs # Additional GPU memory may be needed for multi-modal encoding, which
# to be accounted for when calculating the GPU blocks for # needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager. # vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption, # To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
model_config = self.model_config model_config = self.model_config
mm_config = self.multimodal_config
if supports_multimodal(self.model): input_registry = self.input_registry
max_mm_tokens = MULTIMODAL_REGISTRY \ mm_registry = self.mm_registry
.get_max_multimodal_tokens(model_config) mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs, max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens) max_num_batched_tokens // max_mm_tokens)
...@@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len) .dummy_data_for_profiling(model_config, seq_len, mm_registry)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment