Unverified Commit ebb9cc5f authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[UX][Startup] Account for CUDA graphs during memory profiling (#30515)

parent 85f50eb4
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
import weakref
from collections import Counter from collections import Counter
from collections.abc import Callable from collections.abc import Callable
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any from typing import Any, ClassVar
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -162,6 +163,14 @@ class CUDAGraphWrapper: ...@@ -162,6 +163,14 @@ class CUDAGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
""" """
_all_instances: ClassVar[weakref.WeakSet["CUDAGraphWrapper"]] = weakref.WeakSet()
@classmethod
def clear_all_graphs(cls) -> None:
    """Drop the captured graphs held by every registered wrapper.

    The registry is copied before iterating so that the loop walks a fixed
    snapshot even if the underlying set changes size while graphs are
    being released.
    """
    live_wrappers = tuple(cls._all_instances)
    for wrapper in live_wrappers:
        wrapper.clear_graphs()
def __init__( def __init__(
self, self,
runnable: Callable[..., Any], runnable: Callable[..., Any],
...@@ -192,6 +201,8 @@ class CUDAGraphWrapper: ...@@ -192,6 +201,8 @@ class CUDAGraphWrapper:
# cudagraphs for. # cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
CUDAGraphWrapper._all_instances.add(self)
def __getattr__(self, key: str) -> Any: def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable. # allow accessing the attributes of the runnable.
if hasattr(self.runnable, key): if hasattr(self.runnable, key):
...@@ -205,6 +216,13 @@ class CUDAGraphWrapper: ...@@ -205,6 +216,13 @@ class CUDAGraphWrapper:
# in case we need to access the original runnable. # in case we need to access the original runnable.
return self.runnable return self.runnable
@property
def cudagraph_wrapper(self) -> "CUDAGraphWrapper":
    """Return this object itself.

    NOTE(review): presumably exists so code that reaches a wrapper through
    a `.cudagraph_wrapper` attribute works on this class too - confirm
    against callers.
    """
    return self
def clear_graphs(self) -> None:
    """Forget every captured graph entry stored on this wrapper.

    The dictionary is emptied in place, so its identity is preserved.
    """
    captured = self.concrete_cudagraph_entries
    captured.clear()
def __call__(self, *args: Any, **kwargs: Any) -> Any | None: def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context() forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor batch_descriptor = forward_context.batch_descriptor
......
...@@ -244,6 +244,7 @@ if TYPE_CHECKING: ...@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1628,6 +1629,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1628,6 +1629,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool( "VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0")) int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
), ),
# If set to 1, enable CUDA graph memory estimation during memory profiling.
# This profiles CUDA graph memory usage to provide more accurate KV cache
# memory allocation. Disabled by default to preserve existing behavior.
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
),
} }
......
...@@ -334,8 +334,11 @@ class CudagraphDispatcher: ...@@ -334,8 +334,11 @@ class CudagraphDispatcher:
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]: for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
descs = list(self.cudagraph_keys[mode]) descs = list(self.cudagraph_keys[mode])
if descs: if descs:
# Sort by num_tokens descending (largest first) # Sort by (num_tokens, num_active_loras) descending
descs.sort(key=lambda d: d.num_tokens, reverse=True) descs.sort(
key=lambda d: (d.num_tokens, d.num_active_loras),
reverse=True,
)
result.append((mode, descs)) result.append((mode, descs))
return result return result
...@@ -29,6 +29,7 @@ from vllm.config import ( ...@@ -29,6 +29,7 @@ from vllm.config import (
CUDAGraphMode, CUDAGraphMode,
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
set_current_vllm_config,
update_config, update_config,
) )
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
...@@ -94,6 +95,7 @@ from vllm.multimodal.inputs import ( ...@@ -94,6 +95,7 @@ from vllm.multimodal.inputs import (
PlaceholderRange, PlaceholderRange,
) )
from vllm.multimodal.utils import group_and_batch_mm_kwargs from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -596,6 +598,17 @@ class GPUModelRunner( ...@@ -596,6 +598,17 @@ class GPUModelRunner(
self.async_output_copy_stream = torch.cuda.Stream() self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.Event() self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes is sorted in ascending order.
if (
self.compilation_config.cudagraph_capture_sizes
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
self.cudagraph_batch_sizes = sorted(
self.compilation_config.cudagraph_capture_sizes
)
else:
self.cudagraph_batch_sizes = []
# Cache the device properties. # Cache the device properties.
self._init_device_properties() self._init_device_properties()
...@@ -4727,6 +4740,7 @@ class GPUModelRunner( ...@@ -4727,6 +4740,7 @@ class GPUModelRunner(
remove_lora: bool = True, remove_lora: bool = True,
is_graph_capturing: bool = False, is_graph_capturing: bool = False,
num_active_loras: int = 0, num_active_loras: int = 0,
profile_seq_lens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Run a dummy forward pass to warm up/profile run or capture the Run a dummy forward pass to warm up/profile run or capture the
...@@ -4751,6 +4765,9 @@ class GPUModelRunner( ...@@ -4751,6 +4765,9 @@ class GPUModelRunner(
remove_lora: If False, dummy LoRAs are not destroyed after the run remove_lora: If False, dummy LoRAs are not destroyed after the run
num_active_loras: Number of distinct active LoRAs to capture for. num_active_loras: Number of distinct active LoRAs to capture for.
LoRA is activated when num_active_loras > 0. LoRA is activated when num_active_loras > 0.
profile_seq_lens: If provided, use this value for seq_lens instead
of max_query_len. Used to profile attention workspace that
scales with context length.
""" """
mm_config = self.vllm_config.model_config.multimodal_config mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_encoder_only: if mm_config and mm_config.mm_encoder_only:
...@@ -4881,11 +4898,13 @@ class GPUModelRunner( ...@@ -4881,11 +4898,13 @@ class GPUModelRunner(
# If force_attention is True, we always capture attention. # If force_attention is True, we always capture attention.
# Otherwise, it only happens for cudagraph_runtime_mode=FULL. # Otherwise, it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
if create_mixed_batch: if profile_seq_lens is not None:
seq_lens = profile_seq_lens # type: ignore[assignment]
elif create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use # In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster. # shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches # TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment]
else: else:
seq_lens = max_query_len # type: ignore[assignment] seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[:num_reqs] = seq_lens
...@@ -5298,24 +5317,34 @@ class GPUModelRunner( ...@@ -5298,24 +5317,34 @@ class GPUModelRunner(
self.encoder_cache.clear() self.encoder_cache.clear()
gc.collect() gc.collect()
@instrument(span_name="Capture model") def _init_minimal_kv_cache_for_profiling(self) -> None:
def capture_model(self) -> int: from vllm.v1.core.kv_cache_utils import (
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: get_kv_cache_config_from_groups,
logger.warning( get_kv_cache_groups,
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
) )
return 0
compilation_counter.num_gpu_runner_capture_triggers += 1 kv_cache_spec = self.get_kv_cache_spec()
kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
min_blocks = self.compilation_config.max_cudagraph_capture_size or 1
if kv_cache_groups:
page_size = kv_cache_groups[0].kv_cache_spec.page_size_bytes
group_size = max(len(g.layer_names) for g in kv_cache_groups)
available_memory = min_blocks * page_size * group_size
else:
available_memory = 1 # Attention-free model
start_time = time.perf_counter() minimal_config = get_kv_cache_config_from_groups(
self.vllm_config, kv_cache_groups, available_memory=available_memory
)
self.initialize_kv_cache(minimal_config)
self.cache_config.num_gpu_blocks = minimal_config.num_blocks
logger.debug("Initialized minimal KV cache for CUDA graph profiling")
@staticmethod
@contextmanager @contextmanager
def freeze_gc(): def _freeze_gc():
# Optimize garbage collection during CUDA graph capture.
# Clean up, then freeze all remaining objects from being included
# in future collections.
gc.collect() gc.collect()
should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
if should_freeze: if should_freeze:
...@@ -5327,11 +5356,148 @@ class GPUModelRunner( ...@@ -5327,11 +5356,148 @@ class GPUModelRunner(
gc.unfreeze() gc.unfreeze()
gc.collect() gc.collect()
def _cleanup_profiling_kv_cache(self) -> None:
    """Release the temporary KV cache state created for graph profiling.

    Waits for outstanding device work, drops the runner's references to
    the KV cache tensors and related attention state, resets the block
    count on the cache config, empties each forward-context layer's
    `kv_cache`, and finally runs the garbage collector and empties the
    accelerator cache.
    """
    # NOTE(review): block nesting reconstructed from a flattened diff view;
    # confirm the extent of each `if` against the real file.
    torch.accelerator.synchronize()

    kv_caches = getattr(self, "kv_caches", None)
    if kv_caches:
        # Null each slot before clearing so the tensors are dereferenced
        # even if something else still holds the list contents.
        for slot, _ in enumerate(kv_caches):
            kv_caches[slot] = None  # type: ignore
        kv_caches.clear()

    if hasattr(self, "cross_layers_kv_cache"):
        self.cross_layers_kv_cache = None
        self.cross_layers_attn_backend = None

    if hasattr(self, "attn_groups"):
        self.attn_groups.clear()

    if hasattr(self, "kv_cache_config"):
        del self.kv_cache_config
    self.cache_config.num_gpu_blocks = None

    forward_layers = self.compilation_config.static_forward_context.values()
    for layer in forward_layers:
        if not hasattr(layer, "kv_cache"):
            continue
        layer.kv_cache = []

    gc.collect()
    torch.accelerator.empty_cache()

    logger.debug("Cleaned up profiling KV cache and CUDA graphs")
@torch.inference_mode()
def profile_cudagraph_memory(self) -> int:
    """Estimate how many bytes of GPU memory CUDA graph capture will use.

    A minimal KV cache is set up, then for each capture mode returned by
    the dispatcher the first two batch descriptors are warmed up and
    captured into a temporary graph pool while free device memory is
    sampled before and after each capture. The first sample is treated as
    the one-off cost for that mode and the second as the incremental cost
    of each further graph.

    All temporary state (capturing flag, captured graphs, per-wrapper graph
    pools, dummy LoRAs, the minimal KV cache and the capture counter) is
    restored in a `finally` block, so a failure during profiling does not
    leave the runner in capture mode or pointing at the temporary pool.

    Returns:
        The estimated CUDA graph memory in bytes, or 0 when the dispatcher
        reports no graphs to capture.
    """
    # NOTE(review): assumes only the KV cache initialisation needs the
    # config context (nesting reconstructed from a flattened diff) - confirm.
    with set_current_vllm_config(self.vllm_config):
        self._init_minimal_kv_cache_for_profiling()

    saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured

    capture_descs = self.cudagraph_dispatcher.get_capture_descs()

    total_graphs = sum(len(descs) for _, descs in capture_descs)
    if total_graphs == 0:
        logger.debug("No CUDA graphs will be captured, skipping profiling")
        self._cleanup_profiling_kv_cache()
        return 0

    logger.info(
        "Profiling CUDA graph memory: %s",
        ", ".join(
            f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})"
            for mode, descs in capture_descs
            if descs
        ),
    )

    # Hold (instance, pool) pairs rather than keying on id(): the registry
    # only references wrappers weakly, and ids of collected objects can be
    # reused.
    original_pools: list[tuple[Any, Any]] = []
    shared_memory_estimate: dict[CUDAGraphMode, int] = {}
    per_graph_estimate: dict[CUDAGraphMode, int] = {}

    try:
        # Use a temporary pool for profiling to avoid fragmentation in the
        # main pool.
        profiling_pool = current_platform.graph_pool_handle()
        for instance in list(CUDAGraphWrapper._all_instances):
            original_pools.append((instance, instance.graph_pool))
            instance.graph_pool = profiling_pool

        set_cudagraph_capturing_enabled(True)
        with self._freeze_gc(), graph_capture(device=self.device):
            torch.accelerator.synchronize()
            torch.accelerator.empty_cache()

            for mode, descs in capture_descs:
                # Two samples per mode: first capture, then one more graph.
                profile_descs = descs[:2]
                mem_samples: list[int] = []

                for i, desc in enumerate(profile_descs):
                    mem_before = torch.cuda.mem_get_info()[0]
                    self._warmup_and_capture(
                        desc,
                        cudagraph_runtime_mode=mode,
                        # Only the first FULL capture overrides seq_lens.
                        profile_seq_lens=(
                            min(
                                self.max_model_len,
                                self.max_num_tokens // desc.num_tokens,
                            )
                            if mode == CUDAGraphMode.FULL and i == 0
                            else None
                        ),
                    )
                    torch.accelerator.synchronize()
                    free_after = torch.cuda.mem_get_info()[0]
                    mem_samples.append(mem_before - free_after)

                first_capture = mem_samples[0]
                # Use at least 1 MiB per graph for driver overhead
                per_graph = max(
                    mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20
                )

                shared_memory_estimate[mode] = first_capture
                per_graph_estimate[mode] = per_graph * (len(descs) - 1)

                logger.debug(
                    "Estimated %s CUDA graph memory: "
                    "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph",
                    mode.name,
                    first_capture / (1 << 20),
                    len(descs),
                    per_graph / (1 << 20),
                )
    finally:
        # Undo every piece of profiling state, even when capture failed.
        set_cudagraph_capturing_enabled(False)
        CUDAGraphWrapper.clear_all_graphs()
        for instance, pool in original_pools:
            instance.graph_pool = pool
        self.maybe_remove_all_loras(self.lora_config)
        self._cleanup_profiling_kv_cache()
        compilation_counter.num_cudagraph_captured = saved_num_cudagraph_captured

    # FULL and PIECEWISE graphs share the global pool at runtime and are
    # never replayed concurrently, so the pool overlays their memory.
    # Take the max to avoid double-counting the overlap.
    total_estimate = max(shared_memory_estimate.values()) + sum(
        per_graph_estimate.values()
    )
    logger.info(
        "Estimated CUDA graph memory: %.2f GiB total",
        total_estimate / (1 << 30),
    )

    return int(total_estimate)
@instrument(span_name="Capture model")
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
# Trigger CUDA graph capture for specific shapes. # Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes # Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes. # can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True) set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device): with self._freeze_gc(), graph_capture(device=self.device):
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]
for ( for (
...@@ -5342,6 +5508,7 @@ class GPUModelRunner( ...@@ -5342,6 +5508,7 @@ class GPUModelRunner(
batch_descriptors=batch_descs, batch_descriptors=batch_descs,
cudagraph_runtime_mode=runtime_mode, cudagraph_runtime_mode=runtime_mode,
) )
torch.accelerator.synchronize()
torch.accelerator.synchronize() torch.accelerator.synchronize()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
...@@ -5353,6 +5520,9 @@ class GPUModelRunner( ...@@ -5353,6 +5520,9 @@ class GPUModelRunner(
# after here. # after here.
set_cudagraph_capturing_enabled(False) set_cudagraph_capturing_enabled(False)
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
# Lock workspace to prevent resizing during execution. # Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling. # Max workspace sizes should have been captured during warmup/profiling.
lock_workspace() lock_workspace()
...@@ -5369,6 +5539,40 @@ class GPUModelRunner( ...@@ -5369,6 +5539,40 @@ class GPUModelRunner(
) )
return cuda_graph_size return cuda_graph_size
def _warmup_and_capture(
    self,
    desc: BatchDescriptor,
    cudagraph_runtime_mode: CUDAGraphMode,
    profile_seq_lens: int | None = None,
    allow_microbatching: bool = False,
    num_warmups: int | None = None,
):
    """Run the warmup passes for one batch descriptor, then capture it.

    Args:
        desc: Batch descriptor supplying the token count, the uniform
            flag and the number of active LoRAs for the dummy runs.
        cudagraph_runtime_mode: Mode used for the capture run. Warmup
            runs always use `CUDAGraphMode.NONE`.
        profile_seq_lens: Forwarded to the capture run only.
        allow_microbatching: Forwarded to every dummy run.
        num_warmups: Number of warmup runs; when None, the value of
            `compilation_config.cudagraph_num_of_warmups` is used.
    """
    warmup_iters = (
        self.compilation_config.cudagraph_num_of_warmups
        if num_warmups is None
        else num_warmups
    )

    token_count = desc.num_tokens
    is_uniform = desc.uniform
    active_loras = desc.num_active_loras

    # Warmups run without graphs, but still exercise attention when the
    # graph to be captured is a FULL one.
    warm_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
    for _ in range(warmup_iters):
        self._dummy_run(
            token_count,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            force_attention=warm_attention,
            uniform_decode=is_uniform,
            allow_microbatching=allow_microbatching,
            skip_eplb=True,
            remove_lora=False,
            num_active_loras=active_loras,
        )

    # Capture run.
    self._dummy_run(
        token_count,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        uniform_decode=is_uniform,
        allow_microbatching=allow_microbatching,
        skip_eplb=True,
        remove_lora=False,
        num_active_loras=active_loras,
        is_graph_capturing=True,
        profile_seq_lens=profile_seq_lens,
    )
def _capture_cudagraphs( def _capture_cudagraphs(
self, self,
batch_descriptors: list[BatchDescriptor], batch_descriptors: list[BatchDescriptor],
...@@ -5383,15 +5587,6 @@ class GPUModelRunner( ...@@ -5383,15 +5587,6 @@ class GPUModelRunner(
return return
uniform_decode = batch_descriptors[0].uniform uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)
# Only rank 0 should print progress bar during capture # Only rank 0 should print progress bar during capture
if is_global_first_rank(): if is_global_first_rank():
...@@ -5406,9 +5601,6 @@ class GPUModelRunner( ...@@ -5406,9 +5601,6 @@ class GPUModelRunner(
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for batch_desc in batch_descriptors: for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
num_active_loras = batch_desc.num_active_loras
# We currently only capture ubatched graphs when its a FULL # We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens # cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched # is above the threshold. Otherwise we just capture a non-ubatched
...@@ -5419,33 +5611,16 @@ class GPUModelRunner( ...@@ -5419,33 +5611,16 @@ class GPUModelRunner(
and uniform_decode and uniform_decode
and check_ubatch_thresholds( and check_ubatch_thresholds(
config=self.vllm_config.parallel_config, config=self.vllm_config.parallel_config,
num_tokens=num_tokens, num_tokens=batch_desc.num_tokens,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
) )
) )
self._warmup_and_capture(
for _ in range(self.compilation_config.cudagraph_num_of_warmups): batch_desc,
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
allow_microbatching=allow_microbatching,
num_active_loras=num_active_loras,
)
# Capture run
dummy_run(
num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
num_active_loras=num_active_loras,
is_graph_capturing=True,
) )
torch.accelerator.synchronize()
self.maybe_remove_all_loras(self.lora_config) self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
......
...@@ -112,16 +112,25 @@ class UBatchWrapper: ...@@ -112,16 +112,25 @@ class UBatchWrapper:
self.cudagraphs: dict[int, CUDAGraphMetaData] = {} self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
self.cudagraph_wrapper = None self.cudagraph_wrapper = None
self.graph_pool = None
if runtime_mode is not CUDAGraphMode.NONE: if runtime_mode is not CUDAGraphMode.NONE:
self.cudagraph_wrapper = CUDAGraphWrapper( self.cudagraph_wrapper = CUDAGraphWrapper(
runnable, vllm_config, runtime_mode=runtime_mode runnable, vllm_config, runtime_mode=runtime_mode
) )
self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config) self.sm_control = self._create_sm_control_context(vllm_config)
self.device = device self.device = device
@property
def graph_pool(self):
    """Graph pool of the wrapped CUDA graph wrapper, or None without one."""
    wrapper = self.cudagraph_wrapper
    return None if wrapper is None else wrapper.graph_pool
def clear_graphs(self) -> None:
    """Drop this object's captured graphs and those of its inner wrapper.

    The local `cudagraphs` mapping is always emptied; the inner wrapper is
    only asked to clear its graphs when one exists.
    """
    self.cudagraphs.clear()
    wrapper = self.cudagraph_wrapper
    if wrapper is None:
        return
    wrapper.clear_graphs()
@staticmethod @staticmethod
def _create_sm_control_context(vllm_config: VllmConfig): def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms: int = envs.VLLM_DBO_COMM_SMS comm_sms: int = envs.VLLM_DBO_COMM_SMS
......
...@@ -44,6 +44,7 @@ from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper ...@@ -44,6 +44,7 @@ from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tracing import instrument from vllm.tracing import instrument
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...@@ -390,8 +391,36 @@ class Worker(WorkerBase): ...@@ -390,8 +391,36 @@ class Worker(WorkerBase):
) as profile_result: ) as profile_result:
self.model_runner.profile_run() self.model_runner.profile_run()
profile_torch_peak = current_platform.memory_stats(self.device).get(
"allocated_bytes.all.peak", 0
)
# Profile CUDA graph memory if graphs will be captured.
cudagraph_memory_estimate = 0
if not self.model_config.enforce_eager:
cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()
# Use the pre-cudagraph torch peak to avoid double-counting.
profile_result.torch_peak_increase = (
profile_torch_peak - profile_result.before_profile.torch_peak
)
profile_result.non_kv_cache_memory = (
profile_result.non_torch_increase
+ profile_result.torch_peak_increase
+ profile_result.weights_memory
)
cudagraph_memory_estimate_applied = (
cudagraph_memory_estimate
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
else 0
)
self.non_torch_memory = profile_result.non_torch_increase self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase self.peak_activation_memory = (
profile_result.torch_peak_increase + cudagraph_memory_estimate_applied
)
self.cudagraph_memory_estimate = cudagraph_memory_estimate
free_gpu_memory = profile_result.after_profile.free_memory free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same # NOTE(woosuk): Here we assume that the other processes using the same
...@@ -406,7 +435,9 @@ class Worker(WorkerBase): ...@@ -406,7 +435,9 @@ class Worker(WorkerBase):
"isolate vLLM in its own container." "isolate vLLM in its own container."
) )
self.available_kv_cache_memory_bytes = ( self.available_kv_cache_memory_bytes = (
self.requested_memory - profile_result.non_kv_cache_memory self.requested_memory
- profile_result.non_kv_cache_memory
- cudagraph_memory_estimate_applied
) )
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
...@@ -428,6 +459,46 @@ class Worker(WorkerBase): ...@@ -428,6 +459,46 @@ class Worker(WorkerBase):
scope="local", scope="local",
) )
if cudagraph_memory_estimate > 0:
total_mem = self.init_snapshot.total_memory
current_util = self.cache_config.gpu_memory_utilization
cg_util_delta = cudagraph_memory_estimate / total_mem
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:
equiv_util = round(current_util - cg_util_delta, 4)
suggested_util = min(
round(current_util + cg_util_delta, 4),
1.0,
)
logger.info(
"CUDA graph memory profiling is enabled "
"(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
"This will become the default in v0.19. "
"The current --gpu-memory-utilization=%.4f is equivalent "
"to --gpu-memory-utilization=%.4f without CUDA graph "
"memory profiling. To maintain the same effective KV "
"cache size as before, increase "
"--gpu-memory-utilization to %.4f.",
current_util,
equiv_util,
suggested_util,
)
else:
suggested_util = min(
round(current_util + cg_util_delta, 4),
1.0,
)
logger.info(
"In v0.19, CUDA graph memory profiling will be enabled "
"by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
"which more accurately accounts for CUDA graph memory "
"during KV cache allocation. To try it now, set "
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
"--gpu-memory-utilization from %.4f to %.4f to maintain "
"the same effective KV cache size.",
current_util,
suggested_util,
)
return int(self.available_kv_cache_memory_bytes) return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None: def get_kv_connector_handshake_metadata(self) -> dict | None:
...@@ -487,14 +558,14 @@ class Worker(WorkerBase): ...@@ -487,14 +558,14 @@ class Worker(WorkerBase):
@instrument(span_name="Warmup (GPU)") @instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float: def compile_or_warm_up_model(self) -> float:
warmup_sizes = [] warmup_sizes: list[int] = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# warm up sizes that are not in cudagraph capture sizes, # warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance, # but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill. # e.g. for the max-num-batched token size in chunked prefill.
compile_sizes = self.vllm_config.compilation_config.compile_sizes compile_sizes = self.vllm_config.compilation_config.compile_sizes
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] # type: ignore[assignment]
cg_capture_sizes: list[int] = [] cg_capture_sizes: list[int] = []
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
...@@ -526,6 +597,22 @@ class Worker(WorkerBase): ...@@ -526,6 +597,22 @@ class Worker(WorkerBase):
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
cuda_graph_memory_bytes = self.model_runner.capture_model() cuda_graph_memory_bytes = self.model_runner.capture_model()
# Compare actual vs estimated CUDA graph memory (if we did profiling)
if (
hasattr(self, "cudagraph_memory_estimate")
and self.cudagraph_memory_estimate > 0
):
GiB = lambda b: round(b / GiB_bytes, 2)
diff = abs(cuda_graph_memory_bytes - self.cudagraph_memory_estimate)
logger.info(
"CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), "
"difference: %s GiB (%.1f%%).",
GiB(cuda_graph_memory_bytes),
GiB(self.cudagraph_memory_estimate),
GiB(diff),
100 * diff / max(cuda_graph_memory_bytes, 1),
)
if self.cache_config.kv_cache_memory_bytes is None and hasattr( if self.cache_config.kv_cache_memory_bytes is None and hasattr(
self, "peak_activation_memory" self, "peak_activation_memory"
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment