Commit 95995bbe (unverified), authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][Engine] Fix GPU memory leaks in engine shutdown and test workaround for...


[ROCm][Engine] Fix GPU memory leaks in engine shutdown and test workaround for async KV prefix cache reset (#38503)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 07351e08
......@@ -124,10 +124,10 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# RIXL/UCX build stages
FROM base AS build_rixl
ARG RIXL_BRANCH="f33a5599"
ARG RIXL_BRANCH="bf4a7214"
ARG RIXL_REPO="https://github.com/ROCm/RIXL.git"
ARG UCX_BRANCH="da3fac2a"
ARG UCX_REPO="https://github.com/ROCm/ucx.git"
ARG UCX_BRANCH="7009d7a1"
ARG UCX_REPO="https://github.com/openucx/ucx.git"
ENV ROCM_PATH=/opt/rocm
ENV UCX_HOME=/usr/local/ucx
ENV RIXL_HOME=/usr/local/rixl
......@@ -165,7 +165,7 @@ RUN cd /usr/local/src && \
--disable-doxygen-doc \
--enable-optimizations \
--enable-devel-headers \
--with-rocm=/opt/rocm \
--with-rocm=${ROCM_PATH} \
--with-verbs \
--with-dm \
--enable-mt && \
......@@ -186,7 +186,12 @@ RUN git clone ${RIXL_REPO} /opt/rixl && \
ninja install
# Generate RIXL wheel
RUN cd /opt/rixl && mkdir -p /app/install && \
# Exclude libcore and libpull from auditwheel: transitive dependencies
# that are not shipped in the wheel and vary across base images.
RUN cd /opt/rixl && \
sed -i "s/--exclude 'libamdhip64\*'/--exclude 'libamdhip64*' --exclude 'libcore*' --exclude 'libpull*'/" \
contrib/build-wheel.sh && \
mkdir -p /app/install && \
./contrib/build-wheel.sh \
--output-dir /app/install \
--rocm-dir ${ROCM_PATH} \
......@@ -431,6 +436,10 @@ COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-pac
ENV MIOPEN_DEBUG_CONV_DIRECT=0
ENV MIOPEN_DEBUG_CONV_GEMM=0
# Use legacy IPC mode for HSA to avoid GPU memory pinning issues with UCX rocm_ipc
# See: https://github.com/ROCm/rocm-libraries/issues/6266
ENV HSA_ENABLE_IPC_MODE_LEGACY=1
# Source code is used in the `python_only_compile.sh` test
# We hide it inside `src/` so that this source code
# will not be imported by other tests
......
......@@ -311,7 +311,7 @@ async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float):
pytest.fail("Process did not exit after SIGTERM with abort timeout")
exit_time = time.time() - start_time
assert exit_time < 2, f"Default shutdown took too long: {exit_time:.1f}s"
assert exit_time < 2.1, f"Default shutdown took too long: {exit_time:.1f}s"
assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}"
await _assert_children_cleaned_up(child_pids)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that GPU memory is fully released after RixlConnector shutdown on ROCm.
Regression test for ROCm/ucx#33: UCX rocm_ipc transport permanently pinned
GPU memory via hsa_amd_ipc_memory_create during ucp_mem_map, causing
GPU memory to be unrecoverable after engine shutdown.
"""
import gc
import pytest
import torch
from vllm.platforms import current_platform
# Module-wide pytest marker: every test in this file is skipped unless the
# current platform reports ROCm.
pytestmark = pytest.mark.skipif(
not current_platform.is_rocm(),
reason="ROCm platform required",
)
def _mb(b: int) -> float:
    """Convert a byte count into megabytes, using 1 MB = 1024 * 1024 bytes."""
    bytes_per_mb = 1024 * 1024
    return b / bytes_per_mb
def _gpu_snapshot(tag: str, prev_alloc: float = 0.0) -> dict:
    """Print one row of GPU memory statistics and return the key figures.

    Args:
        tag: Label for the printed row; also returned under ``"tag"``.
        prev_alloc: Allocated-memory figure (in MB) from an earlier snapshot.
            It is only used to compute the printed ``delta`` column.

    Returns:
        A dict with the keys ``tag``, ``alloc_mb`` (from
        ``torch.accelerator.memory_allocated()``), ``drv_used_mb`` and
        ``drv_pct`` (from ``torch.cuda.mem_get_info()``). The two driver
        figures are ``0`` when the driver query fails or reports a zero total.
    """
    torch.accelerator.synchronize()
    alloc = torch.accelerator.memory_allocated()
    reserved = torch.accelerator.memory_reserved()

    # mem_get_info is not available on torch.accelerator
    drv_used = drv_pct = 0
    try:
        drv_free, drv_total = torch.cuda.mem_get_info()
    except Exception as exc:
        # Still best effort, but say so: callers derive their leak percentage
        # from these figures, and silent zeros would make that check vacuous.
        print(f"  [warn] driver memory stats unavailable for {tag!r}: {exc!r}")
    else:
        # A zero total would otherwise divide by zero; keep the zeros instead.
        if drv_total > 0:
            drv_used = drv_total - drv_free
            drv_pct = drv_used / drv_total * 100

    alloc_mb = _mb(alloc)
    drv_used_mb = _mb(drv_used)
    delta = alloc_mb - prev_alloc
    print(
        f" {tag:<40s} | {alloc_mb:>9.1f} alloc | "
        f"{_mb(reserved):>9.1f} rsrvd | "
        f"{drv_used_mb:>9.1f} driver ({drv_pct:.1f}%) | "
        f"delta {delta:>+9.1f}"
    )
    return {
        "tag": tag,
        "alloc_mb": alloc_mb,
        "drv_used_mb": drv_used_mb,
        "drv_pct": drv_pct,
    }
def _full_gpu_cleanup():
"""gc.collect + torch empty_cache, multiple rounds."""
# Move objects out of the GC's permanent generation so they can be collected.
gc.unfreeze()
# At most three collection rounds; stop as soon as a round reports that it
# found no unreachable objects.
for _ in range(3):
if gc.collect() == 0:
break
# NOTE(review): indentation is missing in this view, so it is unclear whether
# empty_cache() runs once after the loop or once per round - confirm.
torch.accelerator.empty_cache()
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_gpu_memory_rixl_hma(model_name, sw_size):
    """Check that driver-level GPU memory is returned after a NixlConnector run.

    Builds an ``LLM`` with a ``NixlConnector`` KV-transfer config, runs one
    single-token generation, shuts the engine core down and runs the cleanup
    helpers. It then asserts that no more than 10% of the peak driver-level
    growth over the baseline is still in use. A snapshot row is printed at
    each stage.

    ``sw_size`` is supplied by the parametrization but is not read by the
    test body.
    """
    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig
    from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

    rule = "=" * 90
    llm_kwargs = dict(
        model=model_name,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=KVTransferConfig(
            kv_connector="NixlConnector",
            kv_role="kv_both",
        ),
        max_model_len=2048,
        disable_hybrid_kv_cache_manager=False,
        max_num_batched_tokens=1024,
        enable_prefix_caching=False,
        block_size=16,
    )

    print("\n" + rule)
    print("GPU MEMORY -- RIXL NixlConnector HMA (ROCm)")
    print(rule)

    # Start from a clean allocator state before taking the baseline.
    gc.collect()
    torch.accelerator.empty_cache()
    torch.accelerator.reset_peak_memory_stats()
    snap0 = _gpu_snapshot("0. baseline", 0.0)

    # Stage 1: bring the engine up.
    llm = LLM(**llm_kwargs)
    snap1 = _gpu_snapshot("1. after LLM()", snap0["alloc_mb"])

    # Stage 2: one single-token request flagged for remote decode.
    prompts = ["hi" * 1401]
    transfer_params = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    llm.generate(
        prompts,
        SamplingParams(
            temperature=0.0,
            max_tokens=1,
            extra_args={"kv_transfer_params": transfer_params},
        ),
    )
    snap2 = _gpu_snapshot("2. after generate()", snap1["alloc_mb"])

    # Stage 3: tear everything down, in the same order on every run.
    print("\n--- shutdown ---")
    llm.llm_engine.engine_core.shutdown()
    _gpu_snapshot("3. after shutdown()", snap2["alloc_mb"])

    del llm
    _full_gpu_cleanup()
    cleanup_dist_env_and_memory()
    _full_gpu_cleanup()
    torch._dynamo.reset()
    gc.collect()
    torch.accelerator.empty_cache()
    snap_final = _gpu_snapshot("4. final", snap2["alloc_mb"])

    # Summary, allocator view.
    print("\n" + rule)
    baseline = snap0["alloc_mb"]
    final = snap_final["alloc_mb"]
    peak = snap2["alloc_mb"]
    total_alloc = peak - baseline
    if total_alloc > 0:
        pct_note = f" ({(final - baseline) / total_alloc * 100:.1f}%)"
    else:
        pct_note = ""
    print(
        f" PyTorch: baseline={baseline:.0f} peak={peak:.0f} "
        f"final={final:.0f} "
        f"leaked={final - baseline:.0f} MB" + pct_note
    )

    # Summary, driver view.
    drv_base = snap0["drv_used_mb"]
    drv_final = snap_final["drv_used_mb"]
    drv_leaked = drv_final - drv_base
    print(
        f" Driver: baseline={drv_base:.0f} ({snap0['drv_pct']:.1f}%) "
        f"peak={snap2['drv_used_mb']:.0f} ({snap2['drv_pct']:.1f}%) "
        f"final={drv_final:.0f} ({snap_final['drv_pct']:.1f}%) "
        f"leaked={drv_leaked:.0f} MB"
    )
    print(rule)

    # Leak is judged relative to the peak driver usage above the baseline.
    drv_peak = snap2["drv_used_mb"] - drv_base
    if drv_peak > 0:
        leak_pct = drv_leaked / drv_peak * 100
    else:
        leak_pct = 0
    max_leak_pct = 10
    assert leak_pct <= max_leak_pct, (
        f"{drv_leaked:.0f} MB ({leak_pct:.1f}%) of driver-level GPU memory "
        f"not freed after NixlConnector shutdown "
        f"(peak allocation: {drv_peak:.0f} MB, threshold: {max_leak_pct}%)"
    )
@pytest.mark.parametrize("model_name", ["google/gemma-3-1b-it"])
def test_gpu_memory_no_rixl_baseline(model_name):
    """Control run: the same model and engine limits with no KV connector.

    Applies the same 10% driver-level leak threshold as
    test_gpu_memory_rixl_hma, so differences in driver-level memory between
    the two tests can be attributed to the connector.
    """
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

    rule = "=" * 90
    print("\n" + rule)
    print("CONTROL -- same model, no RIXL connector")
    print(rule)

    gc.collect()
    torch.accelerator.empty_cache()
    snap0 = _gpu_snapshot("baseline", 0.0)
    # Every later snapshot prints its delta against the baseline allocation.
    base_alloc = snap0["alloc_mb"]

    engine_args = dict(
        model=model_name,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        max_model_len=2048,
        max_num_batched_tokens=1024,
        enable_prefix_caching=False,
        block_size=16,
    )
    llm = LLM(**engine_args)
    _gpu_snapshot("after LLM()", base_alloc)

    llm.generate(["hi " * 500], SamplingParams(max_tokens=1))
    snap_peak = _gpu_snapshot("after generate()", base_alloc)

    # Teardown mirrors the connector test step for step.
    llm.llm_engine.engine_core.shutdown()
    del llm
    _full_gpu_cleanup()
    cleanup_dist_env_and_memory()
    _full_gpu_cleanup()
    torch._dynamo.reset()
    gc.collect()
    torch.accelerator.empty_cache()
    snap_final = _gpu_snapshot("final", base_alloc)

    drv_base = snap0["drv_used_mb"]
    drv_leaked = snap_final["drv_used_mb"] - drv_base
    drv_peak = snap_peak["drv_used_mb"] - drv_base
    print(f"\n Driver leaked (no rixl): {drv_leaked:.0f} MB")
    print(rule)

    if drv_peak > 0:
        leak_pct = drv_leaked / drv_peak * 100
    else:
        leak_pct = 0
    max_leak_pct = 10
    assert leak_pct <= max_leak_pct, (
        f"{drv_leaked:.0f} MB ({leak_pct:.1f}%) of driver-level GPU memory "
        f"not freed after baseline shutdown "
        f"(peak allocation: {drv_peak:.0f} MB, threshold: {max_leak_pct}%)"
    )
......@@ -87,10 +87,13 @@ class MockSubscriber:
def _wait_for_prefix_cache_reset(llm: LLM) -> None:
"""Wait for async offload transfers to finish so prefix cache can reset.
The GPU-to-CPU offload runs on a CUDA stream asynchronously. While blocks
The GPU-to-CPU offload runs on a CUDA stream asynchronously. While blocks
are still held by the offload worker, ``reset_prefix_cache`` returns
``False``. Retry with a short sleep until it succeeds or we time out.
``False``. Between retries we send a dummy single-token prefill to force
the engine to step, which polls the worker for completed transfers and
frees GPU blocks.
"""
_dummy_params = SamplingParams(max_tokens=1)
deadline = time.monotonic() + _RESET_CACHE_TIMEOUT
while not llm.reset_prefix_cache():
if time.monotonic() > deadline:
......@@ -98,7 +101,13 @@ def _wait_for_prefix_cache_reset(llm: LLM) -> None:
"reset_prefix_cache did not succeed within "
f"{_RESET_CACHE_TIMEOUT}s - async offload may be stuck"
)
time.sleep(0.1)
# Force an engine step so the scheduler polls get_finished()
# and releases GPU blocks held by in-flight async stores.
llm.generate(
[TokensPrompt(prompt_token_ids=[0])],
_dummy_params,
use_tqdm=False,
)
def _latency_test(llm: LLM, subscriber: MockSubscriber):
......
......@@ -45,16 +45,14 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
@lru_cache
def supports_kw(
def _supports_kw(
callable: Callable[..., object],
kw_name: str,
*,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
disallows kwargs names that can also be positional arguments.
"""
"""Internal cached implementation of supports_kw."""
params = inspect.signature(callable).parameters
if not params:
return False
......@@ -99,6 +97,29 @@ def supports_kw(
return False
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Report whether ``kw_name`` can be passed to ``callable`` as a keyword.

    When ``requires_kw_only`` is set, names that could also be supplied
    positionally are rejected.
    """
    # The cached helper keys its lru_cache on the callable it receives. A
    # bound method references its instance, so caching it would keep that
    # instance (and any GPU tensors it owns) alive for as long as the cache
    # entry exists. Hand over the underlying function instead; anything
    # without ``__func__`` is passed through untouched.
    target = getattr(callable, "__func__", callable)
    return _supports_kw(
        target,
        kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Mapping[str, object] | None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
import queue
import signal
......@@ -574,6 +575,12 @@ class EngineCore:
if self.scheduler:
self.scheduler.shutdown()
# Undo the gc.freeze() from __init__ so that the objects allocated
# during engine startup (model weights, KV caches, etc.) become
# visible to the garbage collector again. Without this, deleting
# the engine in-process (e.g. unit tests) leaks GPU memory.
gc.unfreeze()
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
self.model_executor.profile(is_start, profile_prefix)
......
......@@ -5886,6 +5886,20 @@ class GPUModelRunner(
gc.unfreeze()
gc.collect()
def shutdown(self) -> None:
"""Release GPU tensors (model weights, KV caches, workspace) so that
memory is reclaimable when running in the same process."""
# Imported at call time, inside the method, rather than at module import.
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
from vllm.v1.worker.workspace import reset_workspace_manager
# Calls torch.accelerator.synchronize()
self._cleanup_profiling_kv_cache()
# Empty the compilation config's forward-context mapping; presumably its
# entries keep layer objects alive - TODO confirm.
self.compilation_config.static_forward_context.clear()
# Drop this runner's own reference to the model.
self.model = None # type: ignore[assignment]
# Empty the module-level rotary-embedding dict and reset the workspace
# manager; presumably both can hold GPU buffers - confirm in their modules.
_ROPE_DICT.clear()
reset_workspace_manager()
def _cleanup_profiling_kv_cache(self) -> None:
torch.accelerator.synchronize()
if hasattr(self, "kv_caches") and self.kv_caches:
......
......@@ -1015,6 +1015,11 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
# Release GPU resources held by the model runner so that memory
# can be reclaimed when running in-process
if model_runner := getattr(self, "model_runner", None):
model_runner.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment