"vscode:/vscode.git/clone" did not exist on "135cf55cd1d83cd4e18266e343a59e6d9f87856f"
Unverified Commit 2497228a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Factor out logic for requesting initial memory (#30868)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 196cdc32
...@@ -66,27 +66,43 @@ class MemorySnapshot: ...@@ -66,27 +66,43 @@ class MemorySnapshot:
torch_memory: int = 0 torch_memory: int = 0
non_torch_memory: int = 0 non_torch_memory: int = 0
timestamp: float = 0.0 timestamp: float = 0.0
device: torch.types.Device = None
auto_measure: bool = True auto_measure: bool = True
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.device is None:
from vllm.platforms import current_platform
device_fn = current_platform.current_device
assert device_fn is not None
self.device_ = torch.device(device_fn())
else:
self.device_ = torch.device(self.device)
if self.auto_measure: if self.auto_measure:
self.measure() self.measure()
def measure(self) -> None: def measure(self) -> None:
from vllm.platforms import current_platform from vllm.platforms import current_platform
device = self.device_
# we measure the torch peak memory usage via allocated_bytes, # we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` . # rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`, # After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink # `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens. # when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) self.torch_peak = torch.cuda.memory_stats(device).get(
"allocated_bytes.all.peak", 0
)
self.free_memory, self.total_memory = torch.cuda.mem_get_info() self.free_memory, self.total_memory = torch.cuda.mem_get_info(device)
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if ( if (
current_platform.is_cuda() current_platform.is_cuda()
and current_platform.get_device_capability() in shared_sysmem_device_mem_sms and current_platform.get_device_capability(device.index)
in shared_sysmem_device_mem_sms
): ):
# On UMA (Orin, Thor and Spark) platform, # On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory, # where both CPU and GPU rely on system memory,
...@@ -106,12 +122,18 @@ class MemorySnapshot: ...@@ -106,12 +122,18 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes # torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.) # PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage # this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved() self.torch_memory = torch.cuda.memory_reserved(device)
self.non_torch_memory = self.cuda_memory - self.torch_memory self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time() self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
if self.device_ != other.device_:
raise ValueError(
"The two snapshots should be from the same device! "
f"Found: {self.device_} vs. {other.device_}"
)
return MemorySnapshot( return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak, torch_peak=self.torch_peak - other.torch_peak,
free_memory=self.free_memory - other.free_memory, free_memory=self.free_memory - other.free_memory,
...@@ -120,6 +142,7 @@ class MemorySnapshot: ...@@ -120,6 +142,7 @@ class MemorySnapshot:
torch_memory=self.torch_memory - other.torch_memory, torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory, non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp, timestamp=self.timestamp - other.timestamp,
device=self.device_,
auto_measure=False, auto_measure=False,
) )
......
...@@ -56,6 +56,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp ...@@ -56,6 +56,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from .utils import request_memory
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -237,22 +239,8 @@ class Worker(WorkerBase): ...@@ -237,22 +239,8 @@ class Worker(WorkerBase):
torch.cuda.empty_cache() torch.cuda.empty_cache()
# take current memory snapshot # take current memory snapshot
self.init_snapshot = MemorySnapshot() self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = ( self.requested_memory = request_memory(init_snapshot, self.cache_config)
self.init_snapshot.total_memory
* self.cache_config.gpu_memory_utilization
)
if self.init_snapshot.free_memory < self.requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
else: else:
raise RuntimeError(f"Not support device type: {self.device_config.device}") raise RuntimeError(f"Not support device type: {self.device_config.device}")
......
...@@ -8,13 +8,15 @@ from typing_extensions import deprecated ...@@ -8,13 +8,15 @@ from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
...@@ -248,6 +250,28 @@ def gather_mm_placeholders( ...@@ -248,6 +250,28 @@ def gather_mm_placeholders(
return placeholders[is_embed] return placeholders[is_embed]
def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> float:
"""
Calculate the amount of memory required by vLLM, then validate
that the current amount of free memory is sufficient for that.
"""
requested_memory = init_snapshot.total_memory * cache_config.gpu_memory_utilization
if init_snapshot.free_memory < requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device {init_snapshot.device_} "
f"({GiB(init_snapshot.free_memory)}/"
f"{GiB(init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({cache_config.gpu_memory_utilization}, "
f"{GiB(requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
return requested_memory
def add_kv_sharing_layers_to_kv_cache_groups( def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str], shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec], kv_cache_groups: list[KVCacheGroupSpec],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment