Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -330,8 +330,6 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -330,8 +330,6 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
class VoxtralForConditionalGeneration( class VoxtralForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
): ):
merge_by_field_config = True
supported_languages = ISO639_1_SUPPORTED_LANGS supported_languages = ISO639_1_SUPPORTED_LANGS
packed_modules_mapping = { packed_modules_mapping = {
......
...@@ -775,7 +775,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo ...@@ -775,7 +775,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
class WhisperForConditionalGeneration( class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal nn.Module, SupportsTranscription, SupportsMultiModal
): ):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"self_attn.qkv_proj": [ "self_attn.qkv_proj": [
"self_attn.q_proj", "self_attn.q_proj",
......
...@@ -50,6 +50,31 @@ def set_weight_attrs( ...@@ -50,6 +50,31 @@ def set_weight_attrs(
setattr(weight, key, value) setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
"""
Replace a parameter of a layer while maintaining the ability to reload the weight.
Called within implementations of the `process_weights_after_loading` method.
This function should not be called on weights which are tied/shared
Args:
layer: Layer containing parameter to replace
param_name: Name of parameter to replace
new_data: New data of the new parameter
"""
# should not be used on a tied/shared param
if isinstance(new_data, torch.nn.Parameter):
new_data = new_data.data
new_param = torch.nn.Parameter(new_data, requires_grad=False)
old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
if old_param is not None and hasattr(old_param, "weight_loader"):
weight_loader = old_param.weight_loader
set_weight_attrs(new_param, {"weight_loader": weight_loader})
setattr(layer, param_name, new_param)
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = getattr(model, "packed_modules_mapping", None) parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
......
...@@ -11,6 +11,7 @@ import pybase64 ...@@ -11,6 +11,7 @@ import pybase64
import torch import torch
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
from vllm.utils.serial_utils import tensor2base64
from .base import MediaIO from .base import MediaIO
...@@ -135,8 +136,4 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): ...@@ -135,8 +136,4 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
return torch.load(filepath, weights_only=True) return torch.load(filepath, weights_only=True)
def encode_base64(self, media: torch.Tensor) -> str: def encode_base64(self, media: torch.Tensor) -> str:
buffer = BytesIO() return tensor2base64(media)
torch.save(media, buffer)
buffer.seek(0)
binary_data = buffer.read()
return pybase64.b64encode(binary_data).decode("utf-8")
...@@ -2,12 +2,42 @@ ...@@ -2,12 +2,42 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generic, TypeVar from typing import Generic, TypeVar
import numpy as np
_T = TypeVar("_T") _T = TypeVar("_T")
@dataclass
class MediaWithBytes(Generic[_T]):
"""
Wrapper that couples a media object with its original encoded bytes.
This ensures the raw bytes and media object remain synchronized,
preventing cache corruption from in-place modifications.
The wrapper delegates attribute access to the underlying media object,
making it behave transparently like the wrapped type (e.g., PIL.Image).
NOTE: Currently, this wrapper is used only for the image modality.
"""
media: _T
original_bytes: bytes
def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media)."""
return np.array(self.media, *args, **kwargs)
def __getattr__(self, name: str):
"""Delegate attribute access to the underlying media object."""
# This is only called when the attribute is not found on self
return getattr(self.media, name)
class MediaIO(ABC, Generic[_T]): class MediaIO(ABC, Generic[_T]):
@abstractmethod @abstractmethod
def load_bytes(self, data: bytes) -> _T: def load_bytes(self, data: bytes) -> _T:
......
...@@ -25,7 +25,6 @@ from .inputs import ( ...@@ -25,7 +25,6 @@ from .inputs import (
MultiModalBatchedField, MultiModalBatchedField,
MultiModalFeatureSpec, MultiModalFeatureSpec,
MultiModalFieldElem, MultiModalFieldElem,
MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
NestedTensors, NestedTensors,
...@@ -90,7 +89,6 @@ MultiModalCacheValue: TypeAlias = ( ...@@ -90,7 +89,6 @@ MultiModalCacheValue: TypeAlias = (
| MultiModalProcessorCacheItemMetadata | MultiModalProcessorCacheItemMetadata
| MultiModalKwargsItems | MultiModalKwargsItems
| MultiModalKwargsItem | MultiModalKwargsItem
| MultiModalKwargs
| Mapping[str, NestedTensors] | Mapping[str, NestedTensors]
) )
...@@ -108,12 +106,7 @@ class MultiModalCache: ...@@ -108,12 +106,7 @@ class MultiModalCache:
# These are not subclasses of dict # These are not subclasses of dict
if isinstance( if isinstance(
leaf, leaf,
( (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
MultiModalKwargs,
MultiModalKwargsItems,
MultiModalKwargsItem,
MultiModalFieldElem,
),
): ):
return cls.get_item_size(leaf.data) # type: ignore return cls.get_item_size(leaf.data) # type: ignore
......
...@@ -12,6 +12,8 @@ from PIL import Image ...@@ -12,6 +12,8 @@ from PIL import Image
from vllm.logger import init_logger from vllm.logger import init_logger
from .base import MediaWithBytes
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,14 +33,26 @@ class MultiModalHasher: ...@@ -31,14 +33,26 @@ class MultiModalHasher:
if Image.ExifTags.Base.ImageID in exif and isinstance( if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID exif[Image.ExifTags.Base.ImageID], uuid.UUID
): ):
# If the image has exif ImageID tag, use that
return (exif[Image.ExifTags.Base.ImageID].bytes,) return (exif[Image.ExifTags.Base.ImageID].bytes,)
data = {"mode": obj.mode, "data": np.asarray(obj)} data = {"mode": obj.mode, "data": np.asarray(obj)}
if obj.palette is not None: palette = obj.palette
data["palette"] = obj.palette.palette if palette is not None:
if obj.palette.rawmode is not None: data["palette"] = palette.palette
data["palette_rawmode"] = obj.palette.rawmode if palette.rawmode is not None:
data["palette_rawmode"] = palette.rawmode
return cls.iter_item_to_bytes("image", data) return cls.iter_item_to_bytes("image", data)
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
exif = obj.media.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
return (exif[Image.ExifTags.Base.ImageID].bytes,)
return cls.iter_item_to_bytes("image", obj.original_bytes)
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu() tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype tensor_dtype = tensor_obj.dtype
......
...@@ -8,7 +8,7 @@ import pybase64 ...@@ -8,7 +8,7 @@ import pybase64
import torch import torch
from PIL import Image from PIL import Image
from .base import MediaIO from .base import MediaIO, MediaWithBytes
def rescale_image_size( def rescale_image_size(
...@@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]): ...@@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
) )
self.rgba_background_color = rgba_bg self.rgba_background_color = rgba_bg
def _convert_image_mode(self, image: Image.Image) -> Image.Image: def _convert_image_mode(
self, image: Image.Image | MediaWithBytes[Image.Image]
) -> Image.Image:
"""Convert image mode with custom background color.""" """Convert image mode with custom background color."""
if isinstance(image, MediaWithBytes):
image = image.media
if image.mode == self.image_mode: if image.mode == self.image_mode:
return image return image
elif image.mode == "RGBA" and self.image_mode == "RGB": elif image.mode == "RGBA" and self.image_mode == "RGB":
...@@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]): ...@@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
else: else:
return convert_image_mode(image, self.image_mode) return convert_image_mode(image, self.image_mode)
def load_bytes(self, data: bytes) -> Image.Image: def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data)) image = Image.open(BytesIO(data))
image.load() return MediaWithBytes(self._convert_image_mode(image), data)
return self._convert_image_mode(image)
def load_base64(self, media_type: str, data: str) -> Image.Image: def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True)) return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> Image.Image: def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
image = Image.open(filepath) with open(filepath, "rb") as f:
image.load() data = f.read()
return self._convert_image_mode(image) image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
def encode_base64( def encode_base64(
self, self,
......
...@@ -32,6 +32,7 @@ if TYPE_CHECKING: ...@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from PIL.Image import Image from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from .base import MediaWithBytes
from .processing import MultiModalHashes from .processing import MultiModalHashes
else: else:
...@@ -59,7 +60,7 @@ Represents a single audio ...@@ -59,7 +60,7 @@ Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`. item, which can be passed to a HuggingFace `AudioProcessor`.
""" """
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"] ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
""" """
A `transformers.image_utils.ImageInput` representing a single image A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`. item, which can be passed to a HuggingFace `ImageProcessor`.
...@@ -174,6 +175,31 @@ class PlaceholderRange: ...@@ -174,6 +175,31 @@ class PlaceholderRange:
return int(self.is_embed.sum().item()) return int(self.is_embed.sum().item())
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
For example, given `PlaceholderRange(offset=2, length=5)` and
`is_embed = [False, True, False, True, True]`, the output is
`[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.
Returns:
A tuple `(start, end)` representing the start and end
indices (inclusive) of the embedded region.
Returns full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
mask_i = self.is_embed.int()
starts = torch.nonzero(
torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
).flatten()
ends = torch.nonzero(
torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
).flatten()
ranges = torch.stack((starts, ends), dim=1) + self.offset
return [tuple(x) for x in ranges.tolist()]
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
...@@ -200,8 +226,10 @@ Uses a list instead of a tensor if the dimensions of each element do not match. ...@@ -200,8 +226,10 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""Equality check between """
[`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.""" Equality check between
[`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
"""
if isinstance(a, torch.Tensor): if isinstance(a, torch.Tensor):
return isinstance(b, torch.Tensor) and torch.equal(a, b) return isinstance(b, torch.Tensor) and torch.equal(a, b)
elif isinstance(b, torch.Tensor): elif isinstance(b, torch.Tensor):
...@@ -220,13 +248,44 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: ...@@ -220,13 +248,44 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return a == b return a == b
def _nested_tensors_h2d(
tensors: NestedTensors,
device: torch.types.Device,
) -> NestedTensors:
if device is None:
return tensors
return json_map_leaves(
(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x
),
tensors,
)
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
""" """
A dictionary containing nested tensors which have been batched via A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch]. [`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
""" """
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
"""
Equality check between
[`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.
"""
for k in a:
if k not in b:
return False
if not nested_tensors_equal(a[k], b[k]):
return False
return True
@dataclass @dataclass
class MultiModalFeatureSpec: class MultiModalFeatureSpec:
""" """
...@@ -317,7 +376,7 @@ class MultiModalFieldElem: ...@@ -317,7 +376,7 @@ class MultiModalFieldElem:
) # noqa: E721 ) # noqa: E721
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC): class BaseMultiModalField(ABC):
""" """
Defines how to interpret tensor data belonging to a keyword argument in Defines how to interpret tensor data belonging to a keyword argument in
...@@ -325,6 +384,12 @@ class BaseMultiModalField(ABC): ...@@ -325,6 +384,12 @@ class BaseMultiModalField(ABC):
multi-modal items, and vice versa. multi-modal items, and vice versa.
""" """
keep_on_cpu: bool = False
"""
If `True`, then this field is excluded from being moved to the accelerator
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
def _field_factory(self, *, modality: str, key: str): def _field_factory(self, *, modality: str, key: str):
f = partial( f = partial(
MultiModalFieldElem, MultiModalFieldElem,
...@@ -369,6 +434,7 @@ class BaseMultiModalField(ABC): ...@@ -369,6 +434,7 @@ class BaseMultiModalField(ABC):
self, self,
elems: list[MultiModalFieldElem], elems: list[MultiModalFieldElem],
*, *,
device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
) -> NestedTensors: ) -> NestedTensors:
""" """
...@@ -382,11 +448,17 @@ class BaseMultiModalField(ABC): ...@@ -382,11 +448,17 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1: if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}") raise ValueError(f"Cannot merge different {field_types=}")
if device is not None and self.keep_on_cpu:
device = "cpu"
if pin_memory and self.keep_on_cpu:
pin_memory = False
batch = [elem.data for elem in elems] batch = [elem.data for elem in elems]
return self._reduce_data(batch, pin_memory=pin_memory) out = self._reduce_data(batch, pin_memory=pin_memory)
return _nested_tensors_h2d(out, device=device)
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField): class MultiModalBatchedField(BaseMultiModalField):
""" """
Info: Info:
...@@ -428,7 +500,7 @@ class MultiModalBatchedField(BaseMultiModalField): ...@@ -428,7 +500,7 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch return batch
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField): class MultiModalFlatField(BaseMultiModalField):
""" """
Info: Info:
...@@ -488,7 +560,7 @@ class MultiModalFlatField(BaseMultiModalField): ...@@ -488,7 +560,7 @@ class MultiModalFlatField(BaseMultiModalField):
return [e for elem in batch for e in elem] return [e for elem in batch for e in elem]
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField): class MultiModalSharedField(BaseMultiModalField):
""" """
Info: Info:
...@@ -515,9 +587,10 @@ class MultiModalSharedField(BaseMultiModalField): ...@@ -515,9 +587,10 @@ class MultiModalSharedField(BaseMultiModalField):
return batch[0] return batch[0]
@dataclass(frozen=True)
class MultiModalFieldConfig: class MultiModalFieldConfig:
@staticmethod @staticmethod
def batched(modality: str): def batched(modality: str, *, keep_on_cpu: bool = False):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data. indexing into the first dimension of the underlying data.
...@@ -525,6 +598,7 @@ class MultiModalFieldConfig: ...@@ -525,6 +598,7 @@ class MultiModalFieldConfig:
Args: Args:
modality: The modality of the multi-modal item that uses this modality: The modality of the multi-modal item that uses this
keyword argument. keyword argument.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -541,7 +615,7 @@ class MultiModalFieldConfig: ...@@ -541,7 +615,7 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalBatchedField(), field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
modality=modality, modality=modality,
) )
...@@ -550,6 +624,8 @@ class MultiModalFieldConfig: ...@@ -550,6 +624,8 @@ class MultiModalFieldConfig:
modality: str, modality: str,
slices: Sequence[slice] | Sequence[Sequence[slice]], slices: Sequence[slice] | Sequence[Sequence[slice]],
dim: int = 0, dim: int = 0,
*,
keep_on_cpu: bool = False,
): ):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
...@@ -562,6 +638,7 @@ class MultiModalFieldConfig: ...@@ -562,6 +638,7 @@ class MultiModalFieldConfig:
slices (dim>0) that is used to extract the data corresponding slices (dim>0) that is used to extract the data corresponding
to it. to it.
dim: The dimension to extract data, default to 0. dim: The dimension to extract data, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -596,12 +673,22 @@ class MultiModalFieldConfig: ...@@ -596,12 +673,22 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices, dim=dim), field=MultiModalFlatField(
slices=slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
),
modality=modality, modality=modality,
) )
@staticmethod @staticmethod
def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0): def flat_from_sizes(
modality: str,
size_per_item: "torch.Tensor",
dim: int = 0,
*,
keep_on_cpu: bool = False,
):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data. slicing along the first dimension of the underlying data.
...@@ -612,6 +699,7 @@ class MultiModalFieldConfig: ...@@ -612,6 +699,7 @@ class MultiModalFieldConfig:
size_per_item: For each multi-modal item, the size of the slice size_per_item: For each multi-modal item, the size of the slice
that is used to extract the data corresponding to it. that is used to extract the data corresponding to it.
dim: The dimension to slice, default to 0. dim: The dimension to slice, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -659,10 +747,20 @@ class MultiModalFieldConfig: ...@@ -659,10 +747,20 @@ class MultiModalFieldConfig:
for i in range(len(size_per_item)) for i in range(len(size_per_item))
] ]
return MultiModalFieldConfig.flat(modality, slices, dim=dim) return MultiModalFieldConfig.flat(
modality,
slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
)
@staticmethod @staticmethod
def shared(modality: str, batch_size: int): def shared(
modality: str,
batch_size: int,
*,
keep_on_cpu: bool = False,
):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data. taking the entirety of the underlying data.
...@@ -673,6 +771,7 @@ class MultiModalFieldConfig: ...@@ -673,6 +771,7 @@ class MultiModalFieldConfig:
modality: The modality of the multi-modal item that uses this modality: The modality of the multi-modal item that uses this
keyword argument. keyword argument.
batch_size: The number of multi-modal items which share this data. batch_size: The number of multi-modal items which share this data.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example: Example:
...@@ -691,18 +790,15 @@ class MultiModalFieldConfig: ...@@ -691,18 +790,15 @@ class MultiModalFieldConfig:
``` ```
""" """
return MultiModalFieldConfig( return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size), field=MultiModalSharedField(
batch_size=batch_size,
keep_on_cpu=keep_on_cpu,
),
modality=modality, modality=modality,
) )
def __init__(self, field: BaseMultiModalField, modality: str) -> None: field: BaseMultiModalField
super().__init__() modality: str
self.field = field
self.modality = modality
def __repr__(self) -> str:
return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
def build_elems( def build_elems(
self, self,
...@@ -727,7 +823,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -727,7 +823,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
modality=modality, modality=modality,
key="dummy", key="dummy",
data=torch.empty(nbytes, dtype=torch.uint8), data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(1), field=MultiModalSharedField(batch_size=1),
) )
return MultiModalKwargsItem.from_elems([mm_elem]) return MultiModalKwargsItem.from_elems([mm_elem])
...@@ -822,7 +918,13 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -822,7 +918,13 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
return self # type: ignore[return-value] return self # type: ignore[return-value]
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": def get_data(
self,
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> BatchedTensorInputs:
"""Construct a dictionary of keyword arguments to pass to the model."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for modality, items in self.items(): for modality, items in self.items():
for i, item in enumerate(items): for i, item in enumerate(items):
...@@ -834,12 +936,16 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): ...@@ -834,12 +936,16 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
for key, elem in item.items(): for key, elem in item.items():
elems_by_key[key].append(elem) elems_by_key[key].append(elem)
return MultiModalKwargs( data = {
{ key: elems[0].field.reduce_data(
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) elems,
for key, elems in elems_by_key.items() device=device,
} pin_memory=pin_memory,
) )
for key, elems in elems_by_key.items()
}
return data
MultiModalKwargsOptionalItems: TypeAlias = ( MultiModalKwargsOptionalItems: TypeAlias = (
...@@ -848,6 +954,7 @@ MultiModalKwargsOptionalItems: TypeAlias = ( ...@@ -848,6 +954,7 @@ MultiModalKwargsOptionalItems: TypeAlias = (
) )
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
class MultiModalKwargs(UserDict[str, NestedTensors]): class MultiModalKwargs(UserDict[str, NestedTensors]):
""" """
A dictionary that represents the keyword arguments to A dictionary that represents the keyword arguments to
...@@ -881,91 +988,6 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -881,91 +988,6 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
): ):
return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory) return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)
@staticmethod
def _try_stack(
nested_tensors: NestedTensors, pin_memory: bool = False
) -> NestedTensors:
"""
Stack the inner dimensions that have the same shape in
a nested list of tensors.
Thus, a dimension represented by a list means that the inner
dimensions are different for each element along that dimension.
"""
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
# TODO: Remove these once all models have been migrated
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(list[torch.Tensor], stacked)
if len(tensors_) == 1:
# An optimization when `tensors_` contains only one tensor:
# - produce exactly same result as `torch.stack(tensors_)`
# - will achieve zero-copy if the tensor is contiguous
return tensors_[0].unsqueeze(0).contiguous()
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
outputs = torch.empty(
len(tensors_),
*tensors_[0].shape,
dtype=tensors_[0].dtype,
device=tensors_[0].device,
pin_memory=pin_memory,
)
return torch.stack(tensors_, out=outputs)
@staticmethod
def batch(
inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
# We need to consider the case where each item in the batch
# contains different modalities (i.e. different keys).
item_lists = defaultdict[str, list[NestedTensors]](list)
for inputs in inputs_list:
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalKwargs._try_stack(item_list, pin_memory)
for k, item_list in item_lists.items()
}
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
return json_map_leaves(
lambda x: x.to(device=device, non_blocking=True),
batched_inputs,
)
def __getitem__(self, key: str): def __getitem__(self, key: str):
if key not in self: if key not in self:
raise KeyError( raise KeyError(
......
...@@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of ...@@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from .audio import AudioResampler from .audio import AudioResampler
from .base import MediaWithBytes
from .inputs import ( from .inputs import (
AudioItem, AudioItem,
HfAudioItem, HfAudioItem,
...@@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]): ...@@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
"""Get all data items.""" """Get all data items."""
return [self.get(idx) for idx in range(self.get_count())] return [self.get(idx) for idx in range(self.get_count())]
def get_item_for_hash(self, index: int) -> object:
return self.get(index)
def get_all_items_for_hash(self) -> list[object]:
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
@abstractmethod @abstractmethod
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
"""Get the data to pass to the HF processor.""" """Get the data to pass to the HF processor."""
...@@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]): ...@@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
"""Base class for data items that are arranged in a list.""" """Base class for data items that are arranged in a list."""
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int: def get_count(self) -> int:
return len(self.data) return len(self.data)
def get(self, index: int) -> _T: def get(self, index: int) -> _T:
return self._unwrap(self.data[index])
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
# Return raw item for hashing (preserves original_bytes if present)
return self.data[index] return self.data[index]
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
...@@ -119,11 +134,17 @@ class EmbeddingItems( ...@@ -119,11 +134,17 @@ class EmbeddingItems(
or a list of embedding tensors (one per item). or a list of embedding tensors (one per item).
""" """
def _unwrap(
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
) -> torch.Tensor:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int: def get_count(self) -> int:
return len(self.data) return len(self.data)
def get(self, index: int) -> torch.Tensor: def get(self, index: int) -> torch.Tensor:
return self.data[index] return self._unwrap(self.data[index])
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
return {} return {}
...@@ -463,7 +484,7 @@ class MultiModalDataParser: ...@@ -463,7 +484,7 @@ class MultiModalDataParser:
return ImageEmbeddingItems(data) return ImageEmbeddingItems(data)
if ( if (
isinstance(data, PILImage.Image) isinstance(data, (PILImage.Image, MediaWithBytes))
or isinstance(data, (np.ndarray, torch.Tensor)) or isinstance(data, (np.ndarray, torch.Tensor))
and data.ndim == 3 and data.ndim == 3
): ):
......
...@@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# For None entries, compute a hash; otherwise, use provided ID. # For None entries, compute a hash; otherwise, use provided ID.
computed: list[str] = [] computed: list[str] = []
for i, item in enumerate(items): for i, item in enumerate(items.get_all_items_for_hash()):
item_uuid = mm_uuids_per_modality[i] item_uuid = mm_uuids_per_modality[i]
# NOTE: Even if a item_uuid is provided, we still compute a # NOTE: Even if a item_uuid is provided, we still compute a
......
...@@ -19,7 +19,6 @@ from PIL import Image, UnidentifiedImageError ...@@ -19,7 +19,6 @@ from PIL import Image, UnidentifiedImageError
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager from vllm.utils.registry import ExtensionManager
from .audio import AudioEmbeddingMediaIO, AudioMediaIO from .audio import AudioEmbeddingMediaIO, AudioMediaIO
...@@ -67,8 +66,9 @@ class MediaConnector: ...@@ -67,8 +66,9 @@ class MediaConnector:
to set num_frames for video, set to set num_frames for video, set
`--media-io-kwargs '{"video":{"num_frames":40}}'` `--media-io-kwargs '{"video":{"num_frames":40}}'`
connection: HTTP connection client to download media contents. connection: HTTP connection client to download media contents.
allowed_local_media_path: A local directory to load media files allowed_local_media_path: A local directory to load media files from.
from. allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
""" """
super().__init__() super().__init__()
...@@ -123,16 +123,16 @@ class MediaConnector: ...@@ -123,16 +123,16 @@ class MediaConnector:
"Cannot load local files without `--allowed-local-media-path`." "Cannot load local files without `--allowed-local-media-path`."
) )
filepath = Path(url2pathname(url_spec.path)) filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
if allowed_local_media_path not in filepath.resolve().parents: if allowed_local_media_path not in filepath.resolve().parents:
raise ValueError( raise ValueError(
f"The file path {filepath} must be a subpath " f"The file path {filepath} must be a subpath "
f"of `--allowed-local-media-path` {allowed_local_media_path}." f"of `--allowed-local-media-path {allowed_local_media_path}`."
) )
return media_io.load_file(filepath) return media_io.load_file(filepath)
def _assert_url_in_allowed_media_domains(self, url_spec) -> None: def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
if ( if (
self.allowed_media_domains self.allowed_media_domains
and url_spec.hostname not in self.allowed_media_domains and url_spec.hostname not in self.allowed_media_domains
...@@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality( ...@@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality(
device: torch.types.Device = None, device: torch.types.Device = None,
pin_memory: bool = False, pin_memory: bool = False,
merge_by_field_config: bool | None = None, merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(), multimodal_cpu_fields: Set[str] | None = None,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]: ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance. modality together into the same `MultiModalKwargs` instance.
...@@ -426,59 +426,28 @@ def group_mm_kwargs_by_modality( ...@@ -426,59 +426,28 @@ def group_mm_kwargs_by_modality(
Yields: Yields:
A tuple `(modality, num_items, grouped_kwargs)`. A tuple `(modality, num_items, grouped_kwargs)`.
""" """
if merge_by_field_config is None: if merge_by_field_config is not None:
raise RuntimeError( logger.warning_once(
"`group_mm_kwargs_by_modality` now requires " "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
"`merge_by_field_config` arg, please update your model runner " "is deprecated and will be removed in v0.13."
"according to https://github.com/vllm-project/vllm/pull/25676."
) )
if merge_by_field_config is False: if multimodal_cpu_fields is not None:
logger.warning_once( logger.warning_once(
"The legacy code for batching multi-modal kwargs is deprecated and " "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
"will be removed in v0.12. Please update your model with " "is deprecated and will be removed in v0.13."
"`merge_by_field_config=True` to use the new code defined by "
"`MultiModalFieldConfig`. You can refer to "
"https://github.com/vllm-project/vllm/issues/26149 "
"for some examples on how to do this."
) )
from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalKwargsItems
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items) items_lst = list(items)
mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
mm_kwargs_data = mm_kwargs_items.get_data(
device=device,
pin_memory=pin_memory,
)
if merge_by_field_config: yield modality, len(items_lst), mm_kwargs_data
mm_kwargs_group: BatchedTensorInputs = dict(
MultiModalKwargsItems.from_seq(items_lst).get_data(
pin_memory=pin_memory
)
)
if device is not None:
mm_kwargs_group = {
k: json_map_leaves(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x,
v,
)
if k not in multimodal_cpu_fields
else v
for k, v in mm_kwargs_group.items()
}
else:
mm_kwargs_group = MultiModalKwargs.as_kwargs(
MultiModalKwargs.batch(
[
MultiModalKwargsItems.from_seq([item]).get_data()
for item in items_lst
],
pin_memory=pin_memory,
),
device=device,
)
yield modality, len(items_lst), mm_kwargs_group
def fetch_audio( def fetch_audio(
...@@ -489,9 +458,16 @@ def fetch_audio( ...@@ -489,9 +458,16 @@ def fetch_audio(
Args: Args:
audio_url: URL of the audio file to fetch. audio_url: URL of the audio file to fetch.
audio_io_kwargs: Additional kwargs passed to handle audio IO. audio_io_kwargs: Additional kwargs passed to handle audio IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
""" """
media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs} media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_audio(audio_url) return media_connector.fetch_audio(audio_url)
...@@ -503,9 +479,16 @@ def fetch_image( ...@@ -503,9 +479,16 @@ def fetch_image(
Args: Args:
image_url: URL of the image file to fetch. image_url: URL of the image file to fetch.
image_io_kwargs: Additional kwargs passed to handle image IO. image_io_kwargs: Additional kwargs passed to handle image IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
""" """
media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_image(image_url) return media_connector.fetch_image(image_url)
...@@ -517,7 +500,14 @@ def fetch_video( ...@@ -517,7 +500,14 @@ def fetch_video(
Args: Args:
video_url: URL of the video file to fetch. video_url: URL of the video file to fetch.
video_io_kwargs: Additional kwargs passed to handle video IO. video_io_kwargs: Additional kwargs passed to handle video IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
""" """
media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs} media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_video(video_url) return media_connector.fetch_video(video_url)
...@@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend): ...@@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
return frames, metadata return frames, metadata
class VideoMediaIO(MediaIO[npt.NDArray]): class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
def __init__( def __init__(
self, self,
image_io: ImageMediaIO, image_io: ImageMediaIO,
......
...@@ -10,6 +10,7 @@ import sys ...@@ -10,6 +10,7 @@ import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import psutil
import regex as re import regex as re
import torch import torch
...@@ -132,6 +133,7 @@ class CpuPlatform(Platform): ...@@ -132,6 +133,7 @@ class CpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
...@@ -147,11 +149,21 @@ class CpuPlatform(Platform): ...@@ -147,11 +149,21 @@ class CpuPlatform(Platform):
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
node_dir = "/sys/devices/system/node"
if kv_cache_space is None: if kv_cache_space is None:
kv_cache_space = 4 * GiB_bytes # type: ignore nodes = (
[d for d in os.listdir(node_dir) if d.startswith("node")]
if os.path.exists(node_dir)
else []
)
num_numa_nodes = len(nodes) or 1
free_cpu_memory = psutil.virtual_memory().total // num_numa_nodes
DEFAULT_CPU_MEM_UTILIZATION = 0.5
kv_cache_space = int(free_cpu_memory * DEFAULT_CPU_MEM_UTILIZATION)
kv_cache_space_gib = kv_cache_space / GiB_bytes
logger.warning_once( logger.warning_once(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " "VLLM_CPU_KVCACHE_SPACE not set. Using "
"for CPU backend is not set, using 4 by default." f"{kv_cache_space_gib:.2f} GiB for KV cache."
) )
else: else:
kv_cache_space *= GiB_bytes kv_cache_space *= GiB_bytes
......
...@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec ...@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -149,6 +148,8 @@ class CudaPlatformBase(Platform): ...@@ -149,6 +148,8 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
...@@ -171,7 +172,7 @@ class CudaPlatformBase(Platform): ...@@ -171,7 +172,7 @@ class CudaPlatformBase(Platform):
and cache_config.block_size is not None and cache_config.block_size is not None
): ):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs, # then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the # else we default to CutlassMLA. For each case, we force the
# required block_size. # required block_size.
...@@ -179,23 +180,25 @@ class CudaPlatformBase(Platform): ...@@ -179,23 +180,25 @@ class CudaPlatformBase(Platform):
use_cutlass_mla = False use_cutlass_mla = False
use_flashinfer_mla = False use_flashinfer_mla = False
if envs.VLLM_ATTENTION_BACKEND is None: if vllm_config.attention_config.backend is None:
# Default case # Default case
if cls.is_device_capability(100): if cls.is_device_capability(100) and not use_sparse:
# Blackwell => Force CutlassMLA. # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
use_cutlass_mla = True use_cutlass_mla = True
# TODO: This does not work, because the # Set the backend in AttentionConfig so it's used during
# global_force_attn_backend_context_manager is not set. # backend selection
# See vllm/attention/selector.py:_cached_get_attn_backend vllm_config.attention_config.backend = (
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" AttentionBackendEnum.CUTLASS_MLA
)
else: else:
# Not Blackwell # Not Blackwell
use_flashmla = True use_flashmla = True
else: else:
# Forced case # Forced case
use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" backend = vllm_config.attention_config.backend
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.ops.flashmla import is_flashmla_dense_supported
...@@ -229,27 +232,20 @@ class CudaPlatformBase(Platform): ...@@ -229,27 +232,20 @@ class CudaPlatformBase(Platform):
logger.info( logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend." "Forcing kv cache block size to 64 for FlashMLASparse backend."
) )
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if ( if (
parallel_config.all2all_backend == "deepep_high_throughput" model_config is not None
and parallel_config.data_parallel_size > 1 and model_config.is_mm_prefix_lm
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE and scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
): ):
# TODO: Piecewise Cuda graph might be enabled logger.warning(
# if torch compile cache key issue fixed "Forcing --disable_chunked_mm_input for models "
# See https://github.com/vllm-project/vllm/pull/25093 "with multimodal-bidirectional attention."
logger.info(
"WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
"kernels are optimized for prefill and are incompatible with "
"CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
) )
compilation_config.cudagraph_mode = CUDAGraphMode.NONE scheduler_config.disable_chunked_mm_input = True
@classmethod @classmethod
def get_current_memory_usage( def get_current_memory_usage(
...@@ -286,6 +282,7 @@ class CudaPlatformBase(Platform): ...@@ -286,6 +282,7 @@ class CudaPlatformBase(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
use_mm_prefix,
device_capability, device_capability,
attn_type, attn_type,
) -> tuple[ ) -> tuple[
...@@ -307,6 +304,7 @@ class CudaPlatformBase(Platform): ...@@ -307,6 +304,7 @@ class CudaPlatformBase(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
use_mm_prefix,
device_capability, device_capability,
attn_type, attn_type,
) )
...@@ -330,6 +328,7 @@ class CudaPlatformBase(Platform): ...@@ -330,6 +328,7 @@ class CudaPlatformBase(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
if attn_type is None: if attn_type is None:
...@@ -350,6 +349,7 @@ class CudaPlatformBase(Platform): ...@@ -350,6 +349,7 @@ class CudaPlatformBase(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
use_mm_prefix,
device_capability, device_capability,
attn_type, attn_type,
) )
...@@ -374,6 +374,7 @@ class CudaPlatformBase(Platform): ...@@ -374,6 +374,7 @@ class CudaPlatformBase(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
use_mm_prefix,
device_capability, device_capability,
attn_type, attn_type,
) )
......
...@@ -239,6 +239,7 @@ class Platform: ...@@ -239,6 +239,7 @@ class Platform:
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
......
...@@ -217,6 +217,7 @@ class RocmPlatform(Platform): ...@@ -217,6 +217,7 @@ class RocmPlatform(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
use_mm_prefix,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
...@@ -382,8 +383,26 @@ class RocmPlatform(Platform): ...@@ -382,8 +383,26 @@ class RocmPlatform(Platform):
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE is_eager_execution = compilation_config == CUDAGraphMode.NONE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if parallel_config.decode_context_parallel_size > 1:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif parallel_config.prefill_context_parallel_size > 1:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = 16
...@@ -398,6 +417,9 @@ class RocmPlatform(Platform): ...@@ -398,6 +417,9 @@ class RocmPlatform(Platform):
): ):
compilation_config.custom_ops.append("+rms_norm") compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
......
...@@ -63,8 +63,9 @@ class TpuPlatform(Platform): ...@@ -63,8 +63,9 @@ class TpuPlatform(Platform):
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
block_size: int, block_size: int,
use_mla: bool, use_mla: bool,
has_sink, has_sink: bool,
use_sparse, use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
if use_sparse: if use_sparse:
......
...@@ -48,7 +48,8 @@ class XPUPlatform(Platform): ...@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
block_size: int, block_size: int,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse, use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout
......
...@@ -3,26 +3,27 @@ ...@@ -3,26 +3,27 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import nullcontext from contextlib import nullcontext
from typing import Literal
import torch import torch
from typing_extensions import override from typing_extensions import override
import vllm.envs as envs from vllm.config import ProfilerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
class WorkerProfiler(ABC): class WorkerProfiler(ABC):
def __init__(self) -> None: def __init__(self, profiler_config: ProfilerConfig) -> None:
self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS self._delay_iters = profiler_config.delay_iterations
if self._delay_iters > 0: if self._delay_iters > 0:
logger.info_once( logger.info_once(
"GPU profiling will start " "GPU profiling will start "
f"{self._delay_iters} steps after start_profile." f"{self._delay_iters} steps after start_profile."
) )
self._max_iters = envs.VLLM_PROFILER_MAX_ITERS self._max_iters = profiler_config.max_iterations
if self._max_iters > 0: if self._max_iters > 0:
logger.info_once( logger.info_once(
"GPU profiling will stop " "GPU profiling will stop "
...@@ -133,12 +134,27 @@ class WorkerProfiler(ABC): ...@@ -133,12 +134,27 @@ class WorkerProfiler(ABC):
return nullcontext() return nullcontext()
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
TorchProfilerActivityMap = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"CUDA": torch.profiler.ProfilerActivity.CUDA,
"XPU": torch.profiler.ProfilerActivity.XPU,
}
class TorchProfilerWrapper(WorkerProfiler): class TorchProfilerWrapper(WorkerProfiler):
def __init__(self, worker_name: str, local_rank: int) -> None: def __init__(
super().__init__() self,
profiler_config: ProfilerConfig,
worker_name: str,
local_rank: int,
activities: list[TorchProfilerActivity],
) -> None:
super().__init__(profiler_config)
self.local_rank = local_rank self.local_rank = local_rank
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR self.profiler_config = profiler_config
torch_profiler_trace_dir = profiler_config.torch_profiler_dir
if local_rank in (None, 0): if local_rank in (None, 0):
logger.info( logger.info(
"Torch profiling enabled. Traces will be saved to: %s", "Torch profiling enabled. Traces will be saved to: %s",
...@@ -147,24 +163,23 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -147,24 +163,23 @@ class TorchProfilerWrapper(WorkerProfiler):
logger.debug( logger.debug(
"Profiler config: record_shapes=%s," "Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s", "profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, profiler_config.torch_profiler_record_shapes,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, profiler_config.torch_profiler_with_memory,
envs.VLLM_TORCH_PROFILER_WITH_STACK, profiler_config.torch_profiler_with_stack,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS, profiler_config.torch_profiler_with_flops,
) )
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = torch.profiler.profile( self.profiler = torch.profiler.profile(
activities=[ activities=[TorchProfilerActivityMap[activity] for activity in activities],
torch.profiler.ProfilerActivity.CPU, record_shapes=profiler_config.torch_profiler_record_shapes,
torch.profiler.ProfilerActivity.CUDA, profile_memory=profiler_config.torch_profiler_with_memory,
], with_stack=profiler_config.torch_profiler_with_stack,
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, with_flops=profiler_config.torch_profiler_with_flops,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, torch_profiler_trace_dir,
worker_name=worker_name, worker_name=worker_name,
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, use_gzip=profiler_config.torch_profiler_use_gzip,
), ),
) )
...@@ -176,9 +191,10 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -176,9 +191,10 @@ class TorchProfilerWrapper(WorkerProfiler):
def _stop(self) -> None: def _stop(self) -> None:
self.profiler.stop() self.profiler.stop()
if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: profiler_config = self.profiler_config
rank = self.local_rank rank = self.local_rank
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR if profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = profiler_config.torch_profiler_dir
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total" sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key) table = self.profiler.key_averages().table(sort_by=sort_key)
...@@ -189,6 +205,12 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -189,6 +205,12 @@ class TorchProfilerWrapper(WorkerProfiler):
# only print profiler results on rank 0 # only print profiler results on rank 0
if rank == 0: if rank == 0:
print(table) print(table)
if self.dump_cpu_time_total and rank == 0:
logger.info(
self.profiler.key_averages().table(
sort_by="self_cpu_time_total", row_limit=50
)
)
@override @override
def annotate_context_manager(self, name: str): def annotate_context_manager(self, name: str):
...@@ -196,8 +218,8 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -196,8 +218,8 @@ class TorchProfilerWrapper(WorkerProfiler):
class CudaProfilerWrapper(WorkerProfiler): class CudaProfilerWrapper(WorkerProfiler):
def __init__(self) -> None: def __init__(self, profiler_config: ProfilerConfig) -> None:
super().__init__() super().__init__(profiler_config)
# Note: lazy import to avoid dependency issues if CUDA is not available. # Note: lazy import to avoid dependency issues if CUDA is not available.
import torch.cuda.profiler as cuda_profiler import torch.cuda.profiler as cuda_profiler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment