Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
......@@ -344,14 +344,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -14,10 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
import regex as re
import torch
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
......
......@@ -1228,9 +1228,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights
if "lora" not in name)
loader = AutoWeightsLoader(self)
loader = AutoWeightsLoader(self, skip_substrs=["lora"])
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
......
......@@ -660,8 +660,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -9,7 +9,9 @@ from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import PixtralVisionConfig, TensorType
......@@ -39,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
......@@ -65,14 +67,14 @@ class PixtralImagePixelInputs(TypedDict):
"""
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
The result of stacking {attr}`ImageEncoding.tokens` from each prompt.
The result of stacking `ImageEncoding.tokens` from each prompt.
"""
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
{class}`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
......@@ -224,6 +226,28 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
num_images=num_images)
}
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    """Build dummy processor inputs by encoding a chat completion request.

    The dummy text and the dummy images (if any) are packed into a single
    user message, which is then encoded with the Mistral tokenizer so that
    the prompt is returned as token ids.

    Args:
        seq_len: Forwarded to ``get_dummy_mm_data``.
        mm_counts: Forwarded to ``get_dummy_text`` and
            ``get_dummy_mm_data``.

    Returns:
        ``ProcessorInputs`` whose ``prompt`` holds the encoded tokens and
        whose ``mm_data`` holds the dummy multi-modal data.
    """
    mistral_tokenizer = self.info.get_tokenizer()

    text = self.get_dummy_text(mm_counts)
    mm_data = self.get_dummy_mm_data(seq_len, mm_counts)

    # One text chunk first, followed by one image chunk per dummy image.
    chunks = [TextChunk(text=text)]
    chunks.extend(
        ImageChunk(image=img) for img in mm_data.get("image", []))

    chat_request = ChatCompletionRequest(
        messages=[UserMessage(content=chunks)])
    encoded = mistral_tokenizer.mistral.encode_chat_completion(chat_request)

    return ProcessorInputs(prompt=encoded.tokens, mm_data=mm_data)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):
......@@ -275,8 +299,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
(
prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -535,8 +535,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -530,8 +530,5 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -7,12 +7,12 @@
import copy
import math
import re
import unicodedata
from collections.abc import Collection, Mapping, Sequence, Set
from functools import lru_cache, partial
from typing import Callable, Literal, Optional, TypedDict, Union
import regex as re
import torch
from torch import nn
from torchvision import transforms
......@@ -382,7 +382,8 @@ def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
"""
The logic of adding image pad tokens should only be applied in
{class}`QwenVLProcessor`, so they are patched out here.
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
so they are patched out here.
The definition of the wrapped tokenizer can be found here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
......
......@@ -80,6 +80,7 @@ _TEXT_GENERATION_MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
......@@ -208,6 +209,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
# [Encoder-decoder]
......@@ -382,7 +384,7 @@ class _ModelRegistry:
`model_cls` can be either:
- A {class}`torch.nn.Module` class directly referencing the model.
- A [`torch.nn.Module`][] class directly referencing the model.
- A string in the format `<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
......@@ -78,7 +79,7 @@ SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
......
......@@ -126,8 +126,9 @@ class SolarAttention(nn.Module):
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.head_dim = getattr(config, "head_dim", None)
if self.head_dim is None:
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
......@@ -500,14 +501,5 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -338,13 +338,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=[
"rotary_emb.inv_freq", "rotary_emb.cos_cached",
"rotary_emb.sin_cached"
],
)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -349,8 +349,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=([
"rotary_emb.inv_freq", "lm_head.weight"
] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
......@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from collections.abc import Iterable
from contextlib import nullcontext
from typing import Literal, Optional, Union
import regex as re
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
......@@ -110,6 +111,33 @@ def replace_linear_class(
)
class ConfigOverride:
    """Context manager to temporarily override config attributes.

    On entry, every keyword given to the constructor is written onto
    ``config`` with ``setattr``. On exit, attributes that existed before
    entry get their previous value back, and attributes that did not exist
    are removed again with ``delattr``.

    Args:
        config: Object whose attributes are overridden in place.
        **kwargs: Attribute names mapped to their temporary values.
    """

    def __init__(self, config: PretrainedConfig, **kwargs):
        self.config = config
        self.kwargs = kwargs
        # Previous value of each overridden attribute (None if it was absent).
        self.kwargs_original = {}
        # Names of the attributes that were absent before the override.
        self.kwargs_delete = set()

    def __enter__(self):
        """Override config attributes and return the config."""
        # Reset the bookkeeping so that an instance entered more than once
        # never acts on state left over from an earlier use; otherwise a
        # stale entry in ``kwargs_delete`` would make ``__exit__`` delete an
        # attribute that exists by the time of the later use.
        self.kwargs_original.clear()
        self.kwargs_delete.clear()
        for key, value in self.kwargs.items():
            if not hasattr(self.config, key):
                self.kwargs_delete.add(key)
            self.kwargs_original[key] = getattr(self.config, key, None)
            setattr(self.config, key, value)
        return self.config

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore original config attributes.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        for key, value in self.kwargs_original.items():
            if key in self.kwargs_delete:
                delattr(self.config, key)
            else:
                setattr(self.config, key, value)
class TransformersModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -135,8 +163,17 @@ class TransformersModel(nn.Module):
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# vLLM handles interleaved sliding window attention by creating a new
# interleaved_sliding_window attribute and deleting the sliding_window
# attribute. This breaks the constructors in Transformers so we
# temporarily add the attribute back to construct the model.
config_override = nullcontext()
if hasattr(config, "interleaved_sliding_window"):
config_override = ConfigOverride(
config, sliding_window=config.interleaved_sliding_window)
# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
with torch.device("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
......@@ -262,9 +299,17 @@ class TransformersModel(nn.Module):
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention(
attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
sliding_window = None
if (hasattr(self.config, "interleaved_sliding_window")
and hasattr(self.config, "sliding_window_pattern")
and ((i + 1) % self.config.sliding_window_pattern > 0)):
sliding_window = self.config.interleaved_sliding_window
attention_instances[i] = Attention(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by
......@@ -273,9 +318,9 @@ class TransformersModel(nn.Module):
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{i}.attn")
for i in range(start, end)
}
return attention_instances
def init_buffers(self, module: nn.Module):
"""
......
......@@ -66,7 +66,7 @@ class WeightsMapper:
class AutoWeightsLoader:
"""
Helper class to load weights into a {class}`torch.nn.Module`. It is able
Helper class to load weights into a [`torch.nn.Module`][]. It is able
to automatically detect child modules and parameters while iterating over
the weights only once.
......@@ -80,18 +80,30 @@ class AutoWeightsLoader:
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
"""
# Models trained using early version ColossalAI
# may include these tensors in checkpoint. Skip them.
ROTARY_EMBEDS_UNUSED_WEIGHTS = [
"rotary_emb.inv_freq",
"rotary_emb.cos_cached",
"rotary_emb.sin_cached",
]
def __init__(
    self,
    module: nn.Module,
    *,
    skip_prefixes: Optional[list[str]] = None,
    skip_substrs: Optional[list[str]] = None,
    ignore_unexpected_prefixes: Optional[list[str]] = None,
) -> None:
    """Set up a loader for ``module``.

    Args:
        module: The module that weights are loaded into.
        skip_prefixes: Weight names starting with any of these are
            skipped. Defaults to no prefixes.
        skip_substrs: Weight names containing any of these are skipped.
            The entries of ``ROTARY_EMBEDS_UNUSED_WEIGHTS`` are always
            appended to whatever is passed here.
        ignore_unexpected_prefixes: Stored for later use when deciding
            whether an unexpected weight name may be ignored. Defaults to
            no prefixes.
    """
    super().__init__()

    self.module = module
    self.skip_prefixes = skip_prefixes or []
    # Build a fresh list instead of extending in place: ``+=`` on the
    # caller's list would mutate it, so a list shared between several
    # loaders would grow by the rotary entries on every construction.
    self.skip_substrs = [
        *(skip_substrs or []),
        *self.ROTARY_EMBEDS_UNUSED_WEIGHTS,
    ]
    self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
def _groupby_prefix(
self,
......@@ -119,7 +131,8 @@ class AutoWeightsLoader:
return ".".join((prefix, rest))
def _can_skip(self, qualname: str) -> bool:
return any(qualname.startswith(p) for p in self.skip_prefixes)
return (any(qualname.startswith(p) for p in self.skip_prefixes)
or any(substr in qualname for substr in self.skip_substrs))
def _can_ignore_unexpected(self, qualname: str) -> bool:
return any(
......@@ -257,6 +270,9 @@ class AutoWeightsLoader:
) -> set[str]:
if mapper is not None:
weights = mapper.apply(weights)
# filter out weights with first-prefix/substr to skip in name
weights = ((name, weight) for name, weight in weights
if not self._can_skip(name))
autoloaded_weights = set(self._load_module("", self.module, weights))
return autoloaded_weights
......
......@@ -8,12 +8,12 @@ from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global {class}`~MultiModalRegistry` is used by model runners to
dispatch data processing according to the target model.
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
is used by model runners to dispatch data processing according to the target
model.
:::{seealso}
{ref}`mm-processing`
:::
Info:
[mm_processing](../../../design/mm_processing.html)
"""
__all__ = [
......
......@@ -10,6 +10,7 @@ from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
from vllm.multimodal.image import convert_image_mode
if TYPE_CHECKING:
from vllm.inputs import TokensPrompt
......@@ -35,7 +36,8 @@ class MultiModalHasher:
return np.array(obj).tobytes()
if isinstance(obj, Image.Image):
return cls.item_to_bytes("image", np.array(obj.convert("RGBA")))
return cls.item_to_bytes(
"image", np.asarray(convert_image_mode(obj, "RGBA")))
if isinstance(obj, torch.Tensor):
return cls.item_to_bytes("tensor", obj.numpy())
if isinstance(obj, np.ndarray):
......@@ -43,7 +45,7 @@ class MultiModalHasher:
"ndarray", {
"dtype": obj.dtype.str,
"shape": obj.shape,
"data": obj.data.tobytes(),
"data": obj.tobytes(),
})
logger.warning(
......
......@@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image,
return image
# TODO: Support customizable background color to fill in.
def rgba_to_rgb(
        image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
    """Flatten an RGBA image onto a solid background and return it as RGB.

    Args:
        image: Source image; its mode must be ``"RGBA"``.
        background_color: Colour used to fill the new RGB canvas before the
            source is pasted onto it. Defaults to white.

    Returns:
        A new ``"RGB"`` image with the same size as ``image``.
    """
    assert image.mode == "RGBA"
    flattened = Image.new("RGB", image.size, background_color)
    # Band index 3 of an RGBA image is alpha; it serves as the paste mask.
    alpha_band = image.split()[3]
    flattened.paste(image, mask=alpha_band)
    return flattened
def convert_image_mode(image: Image.Image, to_mode: str):
    """Return ``image`` in mode ``to_mode``.

    The image is handed back untouched when it is already in the requested
    mode. RGBA -> RGB goes through ``rgba_to_rgb``; every other combination
    is delegated to ``image.convert``.

    Args:
        image: The image to convert.
        to_mode: The target mode name, e.g. ``"RGB"``.

    Returns:
        The same object if no conversion was needed, else the converted image.
    """
    current_mode = image.mode

    if current_mode == to_mode:
        return image

    if current_mode == "RGBA" and to_mode == "RGB":
        return rgba_to_rgb(image)

    return image.convert(to_mode)
class ImageMediaIO(MediaIO[Image.Image]):
def __init__(self, *, image_mode: str = "RGB") -> None:
......@@ -32,7 +51,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
def load_bytes(self, data: bytes) -> Image.Image:
image = Image.open(BytesIO(data))
image.load()
return image.convert(self.image_mode)
return convert_image_mode(image, self.image_mode)
def load_base64(self, media_type: str, data: str) -> Image.Image:
return self.load_bytes(base64.b64decode(data))
......@@ -40,7 +59,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
def load_file(self, filepath: Path) -> Image.Image:
image = Image.open(filepath)
image.load()
return image.convert(self.image_mode)
return convert_image_mode(image, self.image_mode)
def encode_base64(
self,
......@@ -51,7 +70,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
image = media
with BytesIO() as buffer:
image = image.convert(self.image_mode)
image = convert_image_mode(image, self.image_mode)
image.save(buffer, image_format)
data = buffer.getvalue()
......
......@@ -29,14 +29,14 @@ _T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
"""
A {class}`transformers.image_utils.ImageInput` representing a single image
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
"""
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
list[np.ndarray], list["torch.Tensor"]]
"""
A {class}`transformers.image_utils.VideoInput` representing a single video
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
"""
......@@ -48,7 +48,7 @@ item, which can be passed to a HuggingFace `AudioProcessor`.
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
"""
A {class}`transformers.image_utils.ImageInput` representing a single image
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
......@@ -58,7 +58,7 @@ these are directly passed to the model without HF processing.
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
"""
A {class}`transformers.image_utils.VideoInput` representing a single video
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
......@@ -108,7 +108,8 @@ MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by {class}`MultiModalDataBuiltins`.
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""
......@@ -169,7 +170,8 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""Equality check between {data}`NestedTensors` objects."""
"""Equality check between
[`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
if isinstance(a, torch.Tensor):
return isinstance(b, torch.Tensor) and torch.equal(a, b)
elif isinstance(b, torch.Tensor):
......@@ -189,7 +191,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
{meth}`MultiModalKwargs.batch`.
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""
......@@ -197,7 +199,7 @@ A dictionary containing nested tensors which have been batched via
class MultiModalFieldElem:
"""
Represents a keyword argument corresponding to a multi-modal item
in {class}`MultiModalKwargs`.
in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
"""
modality: str
......@@ -208,13 +210,15 @@ class MultiModalFieldElem:
key: str
"""
The key of this field in {class}`MultiModalKwargs`,
The key of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
"""
The tensor data of this field in {class}`MultiModalKwargs`,
The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the value of the keyword argument to be passed to the model.
"""
......@@ -237,7 +241,8 @@ class MultiModalFieldElem:
class BaseMultiModalField(ABC):
"""
Defines how to interpret tensor data belonging to a keyword argument in
{class}`MultiModalKwargs` for multiple multi-modal items, and vice versa.
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
multi-modal items, and vice versa.
"""
def _field_factory(self, *, modality: str, key: str):
......@@ -262,10 +267,12 @@ class BaseMultiModalField(ABC):
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
"""
Construct {class}`MultiModalFieldElem` instances to represent
the provided data.
Construct
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
instances to represent the provided data.
This is the inverse of {meth}`reduce_data`.
This is the inverse of
[`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
"""
raise NotImplementedError
......@@ -275,9 +282,11 @@ class BaseMultiModalField(ABC):
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
"""
Merge the data from multiple instances of {class}`MultiModalFieldElem`.
Merge the data from multiple instances of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
This is the inverse of {meth}`build_elems`.
This is the inverse of
[`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
"""
field_types = [type(item.field) for item in elems]
if len(set(field_types)) > 1:
......@@ -289,9 +298,8 @@ class BaseMultiModalField(ABC):
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
:::{seealso}
{func}`MultiModalFieldConfig.batched`
:::
Info:
[`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
"""
def build_elems(
......@@ -320,10 +328,9 @@ class MultiModalBatchedField(BaseMultiModalField):
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
:::{seealso}
{func}`MultiModalFieldConfig.flat`
{func}`MultiModalFieldConfig.flat_from_sizes`
:::
Info:
[`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
[`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
"""
slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
dim: int = 0
......@@ -363,9 +370,8 @@ class MultiModalFlatField(BaseMultiModalField):
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
"""
:::{seealso}
{func}`MultiModalFieldConfig.shared`
:::
Info:
[`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
"""
batch_size: int
......@@ -510,9 +516,8 @@ class MultiModalFieldConfig:
Element 3: [[C],[C]]
```
:::{seealso}
{func}`MultiModalFieldConfig.flat`
:::
Info:
[`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
"""
if size_per_item.ndim != 1:
......@@ -576,8 +581,10 @@ class MultiModalFieldConfig:
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""
A collection of {class}`MultiModalFieldElem`
corresponding to a data item in {class}`MultiModalDataItems`.
A collection of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
corresponding to a data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
@staticmethod
......@@ -596,11 +603,13 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
{meth}`~torch.nn.Module.forward`.
[`torch.nn.Module.forward`][].
The metadata `items` enables us to obtain the keyword arguments
corresponding to each data item in {class}`MultiModalDataItems`, via
{meth}`get_item` and {meth}`get_items`.
corresponding to each data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
[`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
[`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
"""
@staticmethod
......@@ -639,7 +648,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
@staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]):
"""Construct a new {class}`MultiModalKwargs` from multiple items."""
"""Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for item in items:
for key, elem in item.items():
......@@ -735,11 +746,17 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
def maybe_cast_dtype(x: torch.Tensor):
# This mimics the behavior of transformers.BatchFeature
return x.to(dtype=dtype) if x.is_floating_point() else x
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
# NOTE: Cast the dtype before sending it to device
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
json_inputs,
)
......@@ -804,7 +821,7 @@ A dictionary containing placeholder ranges for each modality.
class MultiModalInputs(TypedDict):
"""
Represents the outputs of
{class}`vllm.multimodal.processing.BaseMultiModalProcessor`,
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
ready to be passed to vLLM internals.
"""
......@@ -840,7 +857,8 @@ class MultiModalInputs(TypedDict):
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of {class}`vllm.multimodal.EncDecMultiModalProcessor`
Represents the outputs of
[`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
ready to be passed to vLLM internals.
"""
......
......@@ -28,7 +28,8 @@ else:
class ModalityDataItems(ABC, Generic[_T, _I]):
"""
Represents data items for a modality in {class}`MultiModalDataItems`.
Represents data items for a modality in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
def __init__(self, data: _T, modality: str) -> None:
......@@ -251,15 +252,15 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
"""
As {data}`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
such that each entry corresponds to a list.
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
normalized such that each entry corresponds to a list.
"""
def get_count(self, modality: str, *, strict: bool = True) -> int:
"""
Get the number of data items belonging to a modality.
If `strict=False`, return `0` instead of raising {exc}`KeyError`
If `strict=False`, return `0` instead of raising [`KeyError`][]
even if the modality is not found.
"""
if modality not in self:
......@@ -305,8 +306,8 @@ ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
class MultiModalDataParser:
"""
Parses {data}`~vllm.multimodal.inputs.MultiModalDataDict` into
{class}`MultiModalDataItems`.
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
Args:
target_sr (float, optional): Enables automatic resampling of audio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment