Unverified commit 6f5b460c, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix(gms): restore default allocator cleanup before checkpoint (#8049)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
parent 31909ca3
......@@ -4,9 +4,6 @@
"""Shared Dynamo snapshot helpers for checkpoint lifecycle."""
import asyncio
import ctypes
import ctypes.util
import gc
import logging
import os
import signal
......@@ -253,48 +250,3 @@ def reload_snapshot_restore_identity(
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes"
def _try_release_memory(label: str) -> None:
"""Force Python GC and glibc malloc_trim to return freed memory to the OS.
Logs RSS before/after so you can see how much memory was actually reclaimable.
"""
pid = os.getpid()
def _get_rss_kb() -> int:
try:
with open(f"/proc/{pid}/status") as f:
for line in f:
if line.startswith("VmRSS:"):
return int(line.split()[1])
except Exception:
pass
return 0
rss_before = _get_rss_kb()
collected = gc.collect()
rss_after_gc = _get_rss_kb()
try:
libc_name = ctypes.util.find_library("c")
if libc_name:
libc = ctypes.CDLL(libc_name)
libc.malloc_trim(0)
except Exception as e:
logger.debug("[MemRelease:%s] malloc_trim failed: %s", label, e)
rss_after_trim = _get_rss_kb()
logger.info(
"[MemRelease:%s] gc.collect freed %d objects, "
"RSS: %.2f MiB -> %.2f MiB (gc) -> %.2f MiB (malloc_trim), "
"reclaimed=%.2f MiB",
label,
collected,
rss_before / 1024,
rss_after_gc / 1024,
rss_after_trim / 1024,
(rss_before - rss_after_trim) / 1024,
)
......@@ -4,16 +4,13 @@
"""Dynamo Snapshot integration for SGLang workers."""
import gc
import logging
import time
import sglang as sgl
from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
from .request_handlers.handler_base import SGLangEngineQuiesceController
......@@ -61,7 +58,7 @@ async def prepare_snapshot_engine(
f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)"
)
_try_release_memory("after_engine_load")
gc.collect()
snapshot_controller = EngineSnapshotController(
engine=engine,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import gc
import logging
from collections.abc import Callable
from dynamo.common.utils.snapshot import (
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
from .args import Config
from .handlers import VllmEngineQuiesceController
......@@ -36,7 +33,7 @@ async def prepare_snapshot_engine(
config.engine_args.enable_sleep_mode = True
engine = setup_vllm_engine(config)
_try_release_memory("after_engine_load")
gc.collect()
snapshot_controller = EngineSnapshotController(
engine=engine,
quiesce_controller=VllmEngineQuiesceController(engine[0]),
......
......@@ -32,13 +32,16 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None:
    """Call the original ``torch.cuda.empty_cache`` unless GMS mappings are active.

    A mapping counts as active when its ``handle`` is non-zero. If any GMS
    client memory manager holds such a mapping, the cache flush is skipped and
    a warning with the active count is logged instead.
    """
    # Count only mappings that still hold a handle; entries whose handle is 0
    # do not block the flush.
    active_mapping_count = sum(
        1
        for manager in get_gms_client_memory_managers()
        for mapping in manager.mappings.values()
        if mapping.handle != 0
    )
    if active_mapping_count:
        logger.warning(
            "[GMS] Skipping torch.cuda.empty_cache() - %d active GMS mappings",
            active_mapping_count,
        )
        return
    _original_empty_cache()
......
......@@ -16,6 +16,7 @@ GMS intentionally does not use.
from __future__ import annotations
import gc
import logging
from contextlib import contextmanager
from typing import Optional
......@@ -160,6 +161,8 @@ class GMSMemorySaverImpl:
# abort() drops the current session after unmapping while keeping
# the VA reservation alive for the next resume().
self.allocators[target_tag].abort()
gc.collect()
torch.cuda.empty_cache()
def resume(self, tag: Optional[str] = None) -> None:
for target_tag in _pause_resume_tags(tag):
......
......@@ -12,6 +12,7 @@ Usage:
from __future__ import annotations
import gc
import logging
import sys
from contextlib import nullcontext
......@@ -241,6 +242,9 @@ class GMSWorker(Worker):
"[GMS] No kv_cache manager (shadow mode), skipping kv_cache sleep"
)
gc.collect()
torch.cuda.empty_cache()
free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment