Unverified Commit 21997f45 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Redo] #33110 with threading limit (#33502)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarYunzhuLu <lucia.yunzhu@gmail.com>
parent 67202387
...@@ -17,6 +17,7 @@ from packaging.version import Version ...@@ -17,6 +17,7 @@ from packaging.version import Version
from torch.library import Library, infer_schema from torch.library import Library, infer_schema
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -25,9 +26,7 @@ else: ...@@ -25,9 +26,7 @@ else:
ModelConfig = object ModelConfig = object
IntermediateTensors = object IntermediateTensors = object
import logging logger = init_logger(__name__)
logger = logging.getLogger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
...@@ -104,12 +103,36 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -104,12 +103,36 @@ def set_default_torch_dtype(dtype: torch.dtype):
@contextlib.contextmanager @contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int): def set_default_torch_num_threads(num_threads: int | None = None):
"""Sets the default number of threads for PyTorch to the given value.""" """
Sets the default number of threads for PyTorch to the given value.
`None` means using the value of the environment variable `OMP_NUM_THREADS`
(or `1` if that is not available).
"""
if num_threads is None:
num_threads = 1
try:
num_threads = int(os.environ["OMP_NUM_THREADS"])
except KeyError:
logger.debug_once(
"OMP_NUM_THREADS is not set; defaulting Torch threads to %d.",
num_threads,
)
except ValueError:
logger.warning_once(
"OMP_NUM_THREADS is invalid; defaulting Torch threads to %d.",
num_threads,
)
old_num_threads = torch.get_num_threads() old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads) torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads) try:
yield
finally:
torch.set_num_threads(old_num_threads)
@contextlib.contextmanager @contextlib.contextmanager
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal, cast from typing import Any, Literal, cast
...@@ -35,6 +34,7 @@ from vllm.tokenizers import TokenizerLike ...@@ -35,6 +34,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
...@@ -68,6 +68,19 @@ class InputProcessor: ...@@ -68,6 +68,19 @@ class InputProcessor:
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config) self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.mm_encoder_cache_size = None
if (
self.mm_registry.supports_multimodal_inputs(self.model_config)
and not self.model_config.skip_tokenizer_init
):
with set_default_torch_num_threads():
max_tokens_by_modality = (
mm_registry.get_max_tokens_per_item_by_modality(self.model_config)
)
_, self.mm_encoder_cache_size = compute_mm_encoder_budget(
self.vllm_config.scheduler_config, max_tokens_by_modality
)
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
self.model_config, self.model_config,
...@@ -534,15 +547,7 @@ class InputProcessor: ...@@ -534,15 +547,7 @@ class InputProcessor:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly. # multimodal data and expand prompt token ids accordingly.
num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) with set_request_id(request_id), set_default_torch_num_threads():
if "OMP_NUM_THREADS" not in os.environ:
logger.debug_once(
"OMP_NUM_THREADS is not set; defaulting Torch threads to %d for "
"input preprocessing.",
num_threads,
)
with set_request_id(request_id), set_default_torch_num_threads(num_threads):
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -743,6 +748,25 @@ class InputProcessor: ...@@ -743,6 +748,25 @@ class InputProcessor:
f"model length of {max_prompt_len}. {suggestion}" f"model length of {max_prompt_len}. {suggestion}"
) )
if (
prompt_type == "decoder"
and prompt_inputs["type"] == "multimodal"
and self.mm_encoder_cache_size is not None
):
decoder_mm_positions = prompt_inputs["mm_placeholders"]
for modality, mm_positions in decoder_mm_positions.items():
for mm_position in mm_positions:
embed_length = mm_position.get_num_embeds
if embed_length > self.mm_encoder_cache_size:
raise ValueError(
f"The {prompt_type} prompt contains a(n) {modality} item "
f"with length {embed_length}, which exceeds the "
f"pre-allocated encoder cache size "
f"{self.mm_encoder_cache_size}. Please reduce the input "
f"size or increase the encoder cache size "
f"by setting --limit-mm-per-prompt at startup."
)
def stat_mm_cache(self) -> MultiModalCacheStats | None: def stat_mm_cache(self) -> MultiModalCacheStats | None:
return self.input_preprocessor.stat_mm_cache() return self.input_preprocessor.stat_mm_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment