Unverified Commit 6f5b460c authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix(gms): restore default allocator cleanup before checkpoint (#8049)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 31909ca3
...@@ -4,9 +4,6 @@ ...@@ -4,9 +4,6 @@
"""Shared Dynamo snapshot helpers for checkpoint lifecycle.""" """Shared Dynamo snapshot helpers for checkpoint lifecycle."""
import asyncio import asyncio
import ctypes
import ctypes.util
import gc
import logging import logging
import os import os
import signal import signal
...@@ -253,48 +250,3 @@ def reload_snapshot_restore_identity( ...@@ -253,48 +250,3 @@ def reload_snapshot_restore_identity(
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes" os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes" return get_worker_namespace(), "kubernetes"
def _try_release_memory(label: str) -> None:
"""Force Python GC and glibc malloc_trim to return freed memory to the OS.
Logs RSS before/after so you can see how much memory was actually reclaimable.
"""
pid = os.getpid()
def _get_rss_kb() -> int:
try:
with open(f"/proc/{pid}/status") as f:
for line in f:
if line.startswith("VmRSS:"):
return int(line.split()[1])
except Exception:
pass
return 0
rss_before = _get_rss_kb()
collected = gc.collect()
rss_after_gc = _get_rss_kb()
try:
libc_name = ctypes.util.find_library("c")
if libc_name:
libc = ctypes.CDLL(libc_name)
libc.malloc_trim(0)
except Exception as e:
logger.debug("[MemRelease:%s] malloc_trim failed: %s", label, e)
rss_after_trim = _get_rss_kb()
logger.info(
"[MemRelease:%s] gc.collect freed %d objects, "
"RSS: %.2f MiB -> %.2f MiB (gc) -> %.2f MiB (malloc_trim), "
"reclaimed=%.2f MiB",
label,
collected,
rss_before / 1024,
rss_after_gc / 1024,
rss_after_trim / 1024,
(rss_before - rss_after_trim) / 1024,
)
...@@ -4,16 +4,13 @@ ...@@ -4,16 +4,13 @@
"""Dynamo Snapshot integration for SGLang workers.""" """Dynamo Snapshot integration for SGLang workers."""
import gc
import logging import logging
import time import time
import sglang as sgl import sglang as sgl
from dynamo.common.utils.snapshot import ( from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .request_handlers.handler_base import SGLangEngineQuiesceController from .request_handlers.handler_base import SGLangEngineQuiesceController
...@@ -61,7 +58,7 @@ async def prepare_snapshot_engine( ...@@ -61,7 +58,7 @@ async def prepare_snapshot_engine(
f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)" f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)"
) )
_try_release_memory("after_engine_load") gc.collect()
snapshot_controller = EngineSnapshotController( snapshot_controller = EngineSnapshotController(
engine=engine, engine=engine,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import gc
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from dynamo.common.utils.snapshot import ( from dynamo.common.utils.snapshot import CheckpointConfig, EngineSnapshotController
CheckpointConfig,
EngineSnapshotController,
_try_release_memory,
)
from .args import Config from .args import Config
from .handlers import VllmEngineQuiesceController from .handlers import VllmEngineQuiesceController
...@@ -36,7 +33,7 @@ async def prepare_snapshot_engine( ...@@ -36,7 +33,7 @@ async def prepare_snapshot_engine(
config.engine_args.enable_sleep_mode = True config.engine_args.enable_sleep_mode = True
engine = setup_vllm_engine(config) engine = setup_vllm_engine(config)
_try_release_memory("after_engine_load") gc.collect()
snapshot_controller = EngineSnapshotController( snapshot_controller = EngineSnapshotController(
engine=engine, engine=engine,
quiesce_controller=VllmEngineQuiesceController(engine[0]), quiesce_controller=VllmEngineQuiesceController(engine[0]),
......
...@@ -32,13 +32,16 @@ def patch_empty_cache() -> None: ...@@ -32,13 +32,16 @@ def patch_empty_cache() -> None:
_original_empty_cache = torch.cuda.empty_cache _original_empty_cache = torch.cuda.empty_cache
def safe_empty_cache() -> None: def safe_empty_cache() -> None:
mapping_count = sum( active_mapping_count = sum(
len(manager.mappings) for manager in get_gms_client_memory_managers() 1
for manager in get_gms_client_memory_managers()
for mapping in manager.mappings.values()
if mapping.handle != 0
) )
if mapping_count > 0: if active_mapping_count:
logger.debug( logger.warning(
"[GMS] Skipping torch.cuda.empty_cache() - %d VMM allocations active", "[GMS] Skipping torch.cuda.empty_cache() - %d active GMS mappings",
mapping_count, active_mapping_count,
) )
return return
_original_empty_cache() _original_empty_cache()
......
...@@ -16,6 +16,7 @@ GMS intentionally does not use. ...@@ -16,6 +16,7 @@ GMS intentionally does not use.
from __future__ import annotations from __future__ import annotations
import gc
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Optional
...@@ -160,6 +161,8 @@ class GMSMemorySaverImpl: ...@@ -160,6 +161,8 @@ class GMSMemorySaverImpl:
# abort() drops the current session after unmapping while keeping # abort() drops the current session after unmapping while keeping
# the VA reservation alive for the next resume(). # the VA reservation alive for the next resume().
self.allocators[target_tag].abort() self.allocators[target_tag].abort()
gc.collect()
torch.cuda.empty_cache()
def resume(self, tag: Optional[str] = None) -> None: def resume(self, tag: Optional[str] = None) -> None:
for target_tag in _pause_resume_tags(tag): for target_tag in _pause_resume_tags(tag):
......
...@@ -12,6 +12,7 @@ Usage: ...@@ -12,6 +12,7 @@ Usage:
from __future__ import annotations from __future__ import annotations
import gc
import logging import logging
import sys import sys
from contextlib import nullcontext from contextlib import nullcontext
...@@ -241,6 +242,9 @@ class GMSWorker(Worker): ...@@ -241,6 +242,9 @@ class GMSWorker(Worker):
"[GMS] No kv_cache manager (shadow mode), skipping kv_cache sleep" "[GMS] No kv_cache manager (shadow mode), skipping kv_cache sleep"
) )
gc.collect()
torch.cuda.empty_cache()
free_bytes_after, total = torch.cuda.mem_get_info() free_bytes_after, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after - free_bytes_before freed_bytes = free_bytes_after - free_bytes_before
used_bytes = total - free_bytes_after used_bytes = total - free_bytes_after
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment