Unverified Commit 766bc816 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Store only the keys for multi-modal data in P0 (#22198)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 289b18e6
...@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", ...@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options: If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process)
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
## Multi-modal input limits ## Multi-modal input limits
...@@ -129,20 +129,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory. ...@@ -129,20 +129,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
Here are some examples: Here are some examples:
??? code ```python
from vllm import LLM
```python
from vllm import LLM
# Available for Qwen2-VL series models # Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_kwargs={ mm_processor_kwargs={
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28 "max_pixels": 768 * 768, # Default is 1280 * 28 * 28
}) })
# Available for InternVL series models # Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B", llm = LLM(model="OpenGVLab/InternVL2-2B",
mm_processor_kwargs={ mm_processor_kwargs={
"max_dynamic_patch": 4, # Default is 12 "max_dynamic_patch": 4, # Default is 12
}) })
``` ```
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
This guide covers optimization strategies and performance tuning for vLLM V1. This guide covers optimization strategies and performance tuning for vLLM V1.
!!! tip
Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.
## Preemption ## Preemption
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
...@@ -126,62 +129,44 @@ Data parallelism replicates the entire model across multiple GPU sets and proces ...@@ -126,62 +129,44 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
## Reducing Memory Usage ## Input Processing
If you encounter out-of-memory issues, consider these strategies:
### Context Length and Batch Size ### Parallel Processing
You can reduce memory usage by limiting the context length and batch size: You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
This is useful when input processing (which is run inside the API server)
becomes a bottleneck compared to model execution (which is run inside engine core)
and you have excess CPU capacity.
```python ```console
from vllm import LLM # Run 4 API processes and 1 engine core process
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4
llm = LLM( # Run 4 API processes and 2 engine core processes
model="meta-llama/Llama-3.1-8B-Instruct", vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
max_model_len=2048, # Limit context window
max_num_seqs=4 # Limit batch size
)
``` ```
### Adjust CUDA Graph Compilation !!! note
API server scale-out is only available for online inference.
CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level: !!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
```python because it requires a one-to-one correspondance between API and engine core processes.
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes
)
)
```
Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`: ## Multi-Modal Caching
```python ### Processor Cache
from vllm import LLM
llm = LLM( By default, the multi-modal processor cache is enabled to avoid repeatedly processing
model="meta-llama/Llama-3.1-8B-Instruct", the same multi-modal inputs via Hugging Face `AutoProcessor`,
enforce_eager=True # Disable CUDA graph compilation which commonly occurs in multi-turn conversations.
)
```
### Multimodal Models You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable
(default 4 GiB per API process + 4 GiB per engine core process).
For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request: If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`:
```python ```python
from vllm import LLM llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
disable_mm_preprocessor_cache=True)
# Accept up to 2 images per prompt
llm = LLM(
model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"image": 2}
)
``` ```
...@@ -166,7 +166,7 @@ def parse_args(): ...@@ -166,7 +166,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--disable-mm-preprocessor-cache", "--disable-mm-preprocessor-cache",
action="store_true", action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.", help="If True, disables caching of multi-modal processor.",
) )
return parser.parse_args() return parser.parse_args()
......
...@@ -1565,7 +1565,7 @@ def parse_args(): ...@@ -1565,7 +1565,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--disable-mm-preprocessor-cache", "--disable-mm-preprocessor-cache",
action="store_true", action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.", help="If True, disables caching of multi-modal processor.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, RunnerOption from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
...@@ -257,7 +257,7 @@ def check_logprobs_close( ...@@ -257,7 +257,7 @@ def check_logprobs_close(
def build_model_context( def build_model_context(
model_id: str, model_id: str,
runner: RunnerOption = "auto", runner: RunnerOption = "auto",
dtype: Union[str, torch.dtype] = "auto", dtype: ModelDType = "auto",
model_config_kwargs: Optional[dict[str, Any]] = None, model_config_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None,
...@@ -279,6 +279,7 @@ def build_model_context( ...@@ -279,6 +279,7 @@ def build_model_context(
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
model_config_kwargs = model_config_kwargs or {} model_config_kwargs = model_config_kwargs or {}
limit_mm_per_prompt = limit_mm_per_prompt or {}
model_config = ModelConfig( model_config = ModelConfig(
model_id, model_id,
runner=runner, runner=runner,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = MultiModalCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item)
assert cache.currsize == expected_size
...@@ -6,20 +6,15 @@ from typing import Optional, cast ...@@ -6,20 +6,15 @@ from typing import Optional, cast
import numpy as np import numpy as np
import pytest import pytest
import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo, from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
ProcessingCache, PromptIndexTargets, PromptIndexTargets, PromptInsertion,
PromptInsertion, PromptReplacement, PromptReplacement, apply_text_matches,
apply_text_matches,
apply_token_matches, apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches, find_text_matches, find_token_matches,
...@@ -902,45 +897,6 @@ def test_find_mm_placeholders( ...@@ -902,45 +897,6 @@ def test_find_mm_placeholders(
assert result == expected assert result == expected
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = ProcessingCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("limit", "num_supported", "is_valid"), ("limit", "num_supported", "is_valid"),
......
...@@ -444,8 +444,7 @@ class ModelConfig: ...@@ -444,8 +444,7 @@ class ModelConfig:
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
""" """
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not """If `True`, disable caching of the multi-modal processor."""
recommended)."""
override_neuron_config: dict[str, Any] = field(default_factory=dict) override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config """Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to that are specific to Neuron devices, this argument will be used to
...@@ -1692,6 +1691,31 @@ class ModelConfig: ...@@ -1692,6 +1691,31 @@ class ModelConfig:
def is_multimodal_model(self) -> bool: def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None return self.multimodal_config is not None
@property
def processor_return_mm_hashes(self) -> bool:
"""Whether the multi-modal processor should output hashes."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return not mm_config.disable_mm_preprocessor_cache
@property
def enable_mm_input_cache(self) -> bool:
"""Whether the multi-modal input cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return not mm_config.disable_mm_preprocessor_cache
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
@property @property
def is_cross_encoder(self) -> bool: def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding return (self._model_info.supports_cross_encoding
...@@ -3369,7 +3393,7 @@ class MultiModalConfig: ...@@ -3369,7 +3393,7 @@ class MultiModalConfig:
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
""" """
If `True`, disable caching of the processed multi-modal inputs. If `True`, disable caching of the multi-modal processor.
""" """
interleave_mm_strings: bool = False interleave_mm_strings: bool = False
......
...@@ -1230,16 +1230,16 @@ class EngineArgs: ...@@ -1230,16 +1230,16 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel, enable_multimodal_encoder_data_parallel,
) )
supports_mm_preprocessor_cache = (self.data_parallel_size == 1 if model_config.is_multimodal_model:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb) or data_parallel_external_lb)
if (not supports_mm_preprocessor_cache if (not dp_supports_mm_processor_cache
and model_config.is_multimodal_model
and not model_config.disable_mm_preprocessor_cache): and not model_config.disable_mm_preprocessor_cache):
logger.warning( logger.warning(
"Multi-modal preprocessor cache is not compatible " "Multi-modal processor cache is disabled because "
"with data parallelism when there does not exist a " "it is not compatible with data parallelism when "
"one-to-one correspondance between API process and " "there does not exist a one-to-one correspondance "
"EngineCore process, so the cache will be disabled.") "between API and engine core processes.")
model_config.set_disable_mm_preprocessor_cache(True) model_config.set_disable_mm_preprocessor_cache(True)
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -163,9 +163,8 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -163,9 +163,8 @@ def run_multi_api_server(args: argparse.Namespace):
if model_config.is_multimodal_model and not ( if model_config.is_multimodal_model and not (
orig_disable_mm_preprocessor_cache): orig_disable_mm_preprocessor_cache):
logger.warning( logger.warning("Multi-modal processor cache is disabled because "
"Multi-modal preprocessor cache is not compatible " "it is not compatible with `api_server_count > 1`.")
"with api_server_count > 1, so the cache will be disabled.")
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats log_stats = not engine_args.disable_log_stats
......
...@@ -65,7 +65,7 @@ if TYPE_CHECKING: ...@@ -65,7 +65,7 @@ if TYPE_CHECKING:
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 8 VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
...@@ -561,8 +561,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -561,8 +561,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_VIDEO_LOADER_BACKEND": "VLLM_VIDEO_LOADER_BACKEND":
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
# Cache size (in GiB) for multimodal input cache # Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB # Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": "VLLM_MM_INPUT_CACHE_GIB":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union
import torch
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors
logger = init_logger(__name__)
@dataclass
class MultiModalCacheItemMetadata:
size: int
@classmethod
def wraps(cls, value: "MultiModalCacheValue"):
return cls(size=MultiModalCache.get_item_size(value))
MultiModalCacheValue = Union[
MultiModalKwargs,
MultiModalKwargsItem,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
_V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache:
@classmethod
def get_leaf_size(
cls,
leaf: object,
*,
debug: bool = False,
) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data, debug=debug)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return cls.get_item_size(leaf_data, debug=debug)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf)
@classmethod
def get_item_size(
cls,
value: MultiModalCacheValue,
*,
debug: bool = False,
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
size / GiB_bytes)
return size
@classmethod
def get_lru_cache(
cls,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
return LRUCache(
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
...@@ -16,16 +15,16 @@ import torch ...@@ -16,16 +15,16 @@ import torch
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, NestedTensors, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
...@@ -888,9 +887,6 @@ def find_mm_placeholders( ...@@ -888,9 +887,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCacheOptionalItem(NamedTuple): class ProcessingCacheOptionalItem(NamedTuple):
key: str key: str
value: Optional[MultiModalKwargsItem] value: Optional[MultiModalKwargsItem]
...@@ -901,48 +897,7 @@ class ProcessingCacheItem(NamedTuple): ...@@ -901,48 +897,7 @@ class ProcessingCacheItem(NamedTuple):
value: MultiModalKwargsItem value: MultiModalKwargsItem
class ProcessingCache: class ProcessingCache(MultiModalCache):
@staticmethod
def get_lru_cache(
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
def get_leaf_size(leaf: object) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return get_item_size(leaf.data)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return get_item_size(leaf_data)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
return sys.getsizeof(leaf)
def get_item_size(
value: Union[MultiModalKwargs, MultiModalKwargsItem,
Mapping[str, NestedTensors]]
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_leaf_size, value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB",
type(value), size / GiB_bytes)
return size
return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
def __init__( def __init__(
self, self,
......
...@@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, ...@@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
if mm_positions and len(mm_positions) != len(mm_hashes): if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError( raise ValueError(
"The number of multi-modal positions and hashes must match. This " "The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. " "is likely because you did not enable MM hashing. "
"Please set disable_mm_preprocessor_cache=False.") "Please set `disable_mm_preprocessor_cache=False`.")
# Note that we assume mm_positions is sorted by offset. # Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of # We do not need to check all mm inputs if the start token index is out of
......
...@@ -35,7 +35,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, ...@@ -35,7 +35,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType, ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult) UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
...@@ -124,8 +124,7 @@ class EngineCore: ...@@ -124,8 +124,7 @@ class EngineCore:
log_stats=self.log_stats, log_stats=self.log_stats,
) )
# Setup MM Input Mapper. self.mm_input_cache_server = MultiModalInputCacheServer(
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config) vllm_config.model_config)
# Setup batch queue for pipeline parallelism. # Setup batch queue for pipeline parallelism.
...@@ -413,7 +412,7 @@ class EngineCore: ...@@ -413,7 +412,7 @@ class EngineCore:
# Note on thread safety: no race condition. # Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init, # `mm_input_cache_server` is reset at the end of LLMEngine init,
# and will only accessed in the input processing thread afterwards. # and will only accessed in the input processing thread afterwards.
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes) request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import TYPE_CHECKING, Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and if TYPE_CHECKING:
from vllm.config import ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the # a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1). # server in the core process (=P1).
# #
# -- Client: # -- P0:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs # - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# with built-in caching functionality, with mm_hash as its identifier. # each input multi-modal item (e.g. image),
# - MirroredProcessingCache to keep track of the cached entries and # - BaseMultiModalProcessor processes the input items into `mm_inputs`,
# determine whether to send the MultiModalKwargs to P1. # which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
# #
# -- Server: # -- P1:
# - MirroredProcessingCache to store the MultiModalKwargs from P0. # - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`.
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0),
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to
# the engine for model execution.
# #
# The caching for both client and server is mirrored, and this allows us # Both Client and Server must perform cache update and eviction based on the
# to avoid the serialization of "mm_inputs" (like pixel values) between # same item size. This ensures that the keys of MultiModalInputCacheClient
# client (=P0) and server (=P1) processes if the mm_hash is found in the client # and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# cache. # whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
class MirroredProcessingCache: def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
def __init__(self, model_config): self.enabled = model_config.enable_mm_input_cache
mm_config = model_config.multimodal_config self.mm_cache = MultiModalCache.get_lru_cache(
disable_mm_preprocessor_cache = ( model_config.get_mm_input_cache_gb(),
mm_config is not None and mm_config.disable_mm_preprocessor_cache) MultiModalCacheItemMetadata,
self.use_cache = not disable_mm_preprocessor_cache )
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update_p0( def get_and_update(
self, self,
mm_inputs: Sequence[MultiModalKwargs], mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str], mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]: ) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes) assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache: if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs) assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs return mm_inputs
...@@ -57,20 +71,37 @@ class MirroredProcessingCache: ...@@ -57,20 +71,37 @@ class MirroredProcessingCache:
if self.mm_cache.get(mm_hash) is not None: if self.mm_cache.get(mm_hash) is not None:
mm_input = None mm_input = None
else: else:
self.mm_cache[mm_hash] = mm_input self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_input)
full_mm_inputs.append(mm_input) full_mm_inputs.append(mm_input)
return full_mm_inputs return full_mm_inputs
def get_and_update_p1( def reset(self) -> None:
self.mm_cache.clear()
class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
self.enabled = model_config.enable_mm_input_cache
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargs,
)
def get_and_update(
self, self,
mm_inputs: Sequence[Optional[MultiModalKwargs]], mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str], mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]: ) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes) assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache: if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs) assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs return mm_inputs
...@@ -85,7 +116,5 @@ class MirroredProcessingCache: ...@@ -85,7 +116,5 @@ class MirroredProcessingCache:
return full_mm_inputs return full_mm_inputs
def reset(self) -> bool: def reset(self) -> None:
self.mm_cache.clear() self.mm_cache.clear()
return True
...@@ -19,7 +19,7 @@ from vllm.pooling_params import PoolingParams ...@@ -19,7 +19,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar) validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import ( from vllm.v1.structured_output.backend_outlines import (
...@@ -50,11 +50,8 @@ class Processor: ...@@ -50,11 +50,8 @@ class Processor:
self.tokenizer, self.tokenizer,
mm_registry) mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config) self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config)
# Multi-modal hasher (for images)
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching
@property @property
def mm_registry(self): def mm_registry(self):
...@@ -256,11 +253,13 @@ class Processor: ...@@ -256,11 +253,13 @@ class Processor:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly. # multimodal data and expand prompt token ids accordingly.
return_mm_hashes = (self.model_config.processor_return_mm_hashes
or bool(self.cache_config.enable_prefix_caching))
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=self.use_hash, return_mm_hashes=return_mm_hashes,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.validate_request( current_platform.validate_request(
...@@ -312,7 +311,7 @@ class Processor: ...@@ -312,7 +311,7 @@ class Processor:
sorted_mm_hashes, sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata( ) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"], decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if self.use_hash else None, decoder_inputs["mm_hashes"] if return_mm_hashes else None,
) )
# The output of merged multi-modal processor (`decoder_mm_inputs`) # The output of merged multi-modal processor (`decoder_mm_inputs`)
...@@ -339,7 +338,7 @@ class Processor: ...@@ -339,7 +338,7 @@ class Processor:
] ]
if sorted_mm_hashes is not None: if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs, sorted_mm_hashes) orig_sorted_mm_inputs, sorted_mm_hashes)
else: else:
sorted_mm_inputs = orig_sorted_mm_inputs sorted_mm_inputs = orig_sorted_mm_inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment