Unverified Commit 3556a414 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Limit multimodal input cache by memory (#14805)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 9ed6ee92
...@@ -53,7 +53,7 @@ repos: ...@@ -53,7 +53,7 @@ repos:
entry: tools/mypy.sh 0 "local" entry: tools/mypy.sh 0 "local"
language: python language: python
types: [python] types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
stages: [pre-commit] # Don't run in CI stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9 name: Run mypy for Python 3.9
......
cachetools
psutil psutil
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0 numpy < 2.0.0
......
...@@ -9,6 +9,7 @@ msgspec ...@@ -9,6 +9,7 @@ msgspec
cloudpickle cloudpickle
# packages to install to build the documentation # packages to install to build the documentation
cachetools
pydantic >= 2.8 pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/cpu
torch torch
......
...@@ -48,7 +48,7 @@ def _test_processing_correctness( ...@@ -48,7 +48,7 @@ def _test_processing_correctness(
tokenizer=cached_tokenizer_from_config(model_config), tokenizer=cached_tokenizer_from_config(model_config),
) )
# Ensure that it can fit all of the data # Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30) cache = ProcessingCache(capacity_gb=2048)
processing_info = factories.info(ctx) processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits() supported_mm_limits = processing_info.get_supported_mm_limits()
......
...@@ -56,7 +56,7 @@ if TYPE_CHECKING: ...@@ -56,7 +56,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MM_INPUT_CACHE_SIZE: int = 256 VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
...@@ -432,11 +432,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -432,11 +432,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Cache size for multimodal feature/input cache for multimodal models # Cache size (in GiB) for multimodal input cache
# in unit of number of multimodal data items (e.g. image, video, audio). # Default is 8GiB
# Default is 256 multimodal data items. "VLLM_MM_INPUT_CACHE_GIB":
"VLLM_MM_INPUT_CACHE_SIZE": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "8")),
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
......
# SPDX-License-Identifier: Apache-2.0
"""Helper functions to work with nested JSON structures."""
from collections.abc import Iterable
from functools import reduce
from typing import Callable, TypeVar, Union, overload
_T = TypeVar("_T")
_U = TypeVar("_U")
JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...], _T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
"""Iterate through each leaf in a nested JSON structure."""
if isinstance(value, dict):
for v in value.values():
yield from json_iter_leaves(v)
elif isinstance(value, (list, tuple)):
for v in value:
yield from json_iter_leaves(v)
else:
yield value
def json_map_leaves(
func: Callable[[_T], _U],
value: JSONTree[_T],
) -> JSONTree[_U]:
"""Apply a function to each leaf in a nested JSON structure."""
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: JSONTree[_T],
/,
) -> _T:
...
@overload
def json_reduce_leaves(
func: Callable[[_U, _T], _U],
value: JSONTree[_T],
initial: _U,
/,
) -> _U:
...
def json_reduce_leaves(
func: Callable[..., Union[_T, _U]],
value: JSONTree[_T],
initial: _U = ..., # type: ignore[assignment]
/,
) -> Union[_T, _U]:
"""
Apply a function of two arguments cumulatively to each leaf in a
nested JSON structure, from left to right, so as to reduce the
sequence to a single value.
"""
if initial is ...:
return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type]
return reduce(
func, # type: ignore[arg-type]
json_iter_leaves(value),
initial,
)
...@@ -18,6 +18,7 @@ from transformers.models.pixtral import PixtralProcessor ...@@ -18,6 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -35,7 +36,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -35,7 +36,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......
...@@ -24,6 +24,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, ...@@ -24,6 +24,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul) SiluAndMul)
...@@ -50,7 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -50,7 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant) SupportsMultiModal, SupportsPP, SupportsQuant)
......
...@@ -16,7 +16,8 @@ from PIL.Image import Image ...@@ -16,7 +16,8 @@ from PIL.Image import Image
from transformers import BatchFeature from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import full_groupby, is_list_of
if TYPE_CHECKING: if TYPE_CHECKING:
from .hasher import MultiModalHashDict from .hasher import MultiModalHashDict
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
...@@ -11,14 +11,17 @@ from functools import lru_cache ...@@ -11,14 +11,17 @@ from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import torch
from cachetools import LRUCache
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
...@@ -812,25 +815,50 @@ def find_mm_placeholders( ...@@ -812,25 +815,50 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCache: class ProcessingCache:
def __init__(self, capacity: int) -> None: @staticmethod
def get_lru_cache(
capacity_gb: int,
value_type: type[_V],
) -> LRUCache[str, _V]:
def get_size(leaf: object) -> int:
if isinstance(leaf, torch.Tensor):
return leaf.nbytes # sys.getsizeof doesn't work for tensors
return sys.getsizeof(leaf)
return LRUCache[str, _V](
GiB_bytes * capacity_gb,
getsizeof=lambda x: json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_size, x),
),
)
def __init__(self, capacity_gb: int) -> None:
super().__init__() super().__init__()
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None self.debug_cache_hit_ratio_steps: Optional[int] = None
self.debug_cache_hits = 0
self.debug_cache_total = 0
self._cache = LRUCache[str, MultiModalKwargsItem](capacity) self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
def _maybe_log_cache_stats(self) -> None: def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps steps = self.debug_cache_hit_ratio_steps
if not steps: if not steps:
return return
cache_stats = self._cache.stat() total = self.debug_cache_total
if cache_stats.total % steps == 0: if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f", logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio) self.debug_cache_hits / total)
def get( def get(
self, self,
...@@ -853,6 +881,13 @@ class ProcessingCache: ...@@ -853,6 +881,13 @@ class ProcessingCache:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item}, **{modality: input_item},
**input_kwargs) **input_kwargs)
if self.debug_cache_hit_ratio_steps:
if cache_key in self._cache:
self.debug_cache_hits += 1
self.debug_cache_total += 1
return self._cache.get(cache_key) return self._cache.get(cache_key)
def put( def put(
...@@ -870,7 +905,7 @@ class ProcessingCache: ...@@ -870,7 +905,7 @@ class ProcessingCache:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item}, **{modality: input_item},
**input_kwargs) **input_kwargs)
self._cache.put(cache_key, output_kwargs) self._cache[cache_key] = output_kwargs
class BaseProcessingInfo: class BaseProcessingInfo:
......
...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar ...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
import torch.nn as nn import torch.nn as nn
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer,
...@@ -119,7 +119,7 @@ class MultiModalRegistry: ...@@ -119,7 +119,7 @@ class MultiModalRegistry:
self._limits_by_model = _MultiModalLimits() self._limits_by_model = _MultiModalLimits()
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE) self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
def register_plugin(self, plugin: MultiModalPlugin) -> None: def register_plugin(self, plugin: MultiModalPlugin) -> None:
""" """
......
...@@ -845,22 +845,6 @@ def is_list_of( ...@@ -845,22 +845,6 @@ def is_list_of(
assert_never(check) assert_never(check)
JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"],
tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]: def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
"""Flatten a list of lists to a single list.""" """Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
from typing import Any, Optional from typing import Any, Optional
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry) MultiModalKwargs, MultiModalRegistry)
from vllm.utils import LRUCache from vllm.multimodal.processing import ProcessingCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -30,7 +30,7 @@ logger = init_logger(__name__) ...@@ -30,7 +30,7 @@ logger = init_logger(__name__)
# Both Client and Server must use the same cache size # Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment # (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_SIZE. # variable VLLM_MM_INPUT_CACHE_GIB.
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use # TODO(ywang96): Deprecate this class once all multimodal models migrate to use
...@@ -50,18 +50,20 @@ class MMInputCacheClient: ...@@ -50,18 +50,20 @@ class MMInputCacheClient:
# Init cache # Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE) MultiModalKwargs)
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None self.mm_debug_cache_hit_ratio_steps = None
self.mm_cache_hits = 0 self.mm_debug_cache_hits = 0
self.mm_cache_total = 0 self.mm_debug_cache_total = 0
def cache_hit_ratio(self, steps): def cache_hit_ratio(self, steps):
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0: total = self.mm_debug_cache_total
if total > 0 and total % steps == 0:
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total) self.mm_debug_cache_hits / total)
# NOTE: process_inputs only supports image inputs since all multimodal # NOTE: process_inputs only supports image inputs since all multimodal
# models with other modalities have migrated to use merged preprocessor. # models with other modalities have migrated to use merged preprocessor.
...@@ -71,7 +73,7 @@ class MMInputCacheClient: ...@@ -71,7 +73,7 @@ class MMInputCacheClient:
mm_hashes: Optional[list[str]], mm_hashes: Optional[list[str]],
mm_processor_kwargs: Optional[dict[str, Any]], mm_processor_kwargs: Optional[dict[str, Any]],
precomputed_mm_inputs: Optional[list[MultiModalKwargs]], precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
) -> list[MultiModalKwargs]: ) -> list[Optional[MultiModalKwargs]]:
if precomputed_mm_inputs is None: if precomputed_mm_inputs is None:
image_inputs = mm_data["image"] image_inputs = mm_data["image"]
if not isinstance(image_inputs, list): if not isinstance(image_inputs, list):
...@@ -88,7 +90,7 @@ class MMInputCacheClient: ...@@ -88,7 +90,7 @@ class MMInputCacheClient:
# Process each image input separately, so that later we can schedule # Process each image input separately, so that later we can schedule
# them in a fine-grained manner. # them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided) # Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_inputs: list[MultiModalKwargs] = [] ret_inputs: list[Optional[MultiModalKwargs]] = []
for input_id in range(num_inputs): for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None: if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
...@@ -99,7 +101,7 @@ class MMInputCacheClient: ...@@ -99,7 +101,7 @@ class MMInputCacheClient:
mm_hash = mm_hashes[input_id] mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash) mm_input = self.mm_cache.get(mm_hash)
self.mm_cache_total += 1 self.mm_debug_cache_total += 1
if mm_input is None: if mm_input is None:
if precomputed_mm_inputs is not None: if precomputed_mm_inputs is not None:
# Reuse precomputed input (for merged preprocessor) # Reuse precomputed input (for merged preprocessor)
...@@ -114,9 +116,9 @@ class MMInputCacheClient: ...@@ -114,9 +116,9 @@ class MMInputCacheClient:
if self.use_cache: if self.use_cache:
# Add to cache # Add to cache
assert mm_hash is not None assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input) self.mm_cache[mm_hash] = mm_input
else: else:
self.mm_cache_hits += 1 self.mm_debug_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server mm_input = None # Avoids sending mm_input to Server
ret_inputs.append(mm_input) ret_inputs.append(mm_input)
...@@ -128,14 +130,14 @@ class MMInputCacheServer: ...@@ -128,14 +130,14 @@ class MMInputCacheServer:
def __init__(self, model_config): def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE) MultiModalKwargs)
def get_and_update( def get_and_update(
self, self,
mm_inputs: list[Optional[MultiModalKwargs]], mm_inputs: list[Optional[MultiModalKwargs]],
mm_hashes: list[str], mm_hashes: list[str],
) -> list[MultiModalKwargs]: ) -> list[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes) assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache: if not self.use_cache:
...@@ -148,7 +150,7 @@ class MMInputCacheServer: ...@@ -148,7 +150,7 @@ class MMInputCacheServer:
mm_input = self.mm_cache.get(mm_hash) mm_input = self.mm_cache.get(mm_hash)
assert mm_input is not None assert mm_input is not None
else: else:
self.mm_cache.put(mm_hash, mm_input) self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input) full_mm_inputs.append(mm_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment