Unverified Commit d96a2cf1 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: add TRT-LLM sleep/wake integration with GMS (#7575)


Co-authored-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent 119149f2
......@@ -160,6 +160,23 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
)
add_argument(
g,
flag_name="--load-format",
env_var="DYN_TRTLLM_LOAD_FORMAT",
default="auto",
help="Model weight loading format passed to TensorRT-LLM (e.g. 'auto', 'gms').",
)
add_argument(
g,
flag_name="--model-loader-extra-config",
env_var="DYN_TRTLLM_MODEL_LOADER_EXTRA_CONFIG",
default="",
help=(
"JSON object passed as extra config to the model loader "
"(e.g. '{\"gms_read_only\": true}')."
),
)
add_argument(
g,
flag_name="--disaggregation-mode",
......@@ -453,6 +470,8 @@ class DynamoTrtllmConfig(ConfigBase):
override_engine_args: str
publish_events_and_metrics: bool
disable_request_abort: bool
load_format: str
model_loader_extra_config: str
guided_decoding_backend: Optional[str] = None
disaggregation_mode: DisaggregationMode
......
......@@ -58,6 +58,98 @@ if TYPE_CHECKING:
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class TRTLLMEngineQuiesceController:
"""Adapts TRT-LLM sleep/wake to the standard quiesce controller interface.
Two memory domains: KV cache via TRT-LLM collective_rpc, weights via GMS.
"""
def __init__(self, engine: TensorRTLLMEngine):
self._engine = engine
self._is_quiesced = False
@property
def is_quiesced(self) -> bool:
return self._is_quiesced
async def quiesce(self, tags: list[str] | None = None) -> bool:
if self._is_quiesced:
return False
tags = tags or ["kv_cache", "weights"]
if "kv_cache" in tags:
self._collective_rpc("sleep", ["kv_cache"])
if "weights" in tags:
self._release_gms_weights()
self._is_quiesced = True
return True
async def resume(self, tags: list[str] | None = None) -> bool:
if not self._is_quiesced:
return False
tags = tags or ["kv_cache", "weights"]
if "weights" in tags:
self._restore_gms_weights()
if "kv_cache" in tags:
self._collective_rpc("wakeup", ["kv_cache"])
return True
def mark_resumed(self) -> None:
self._is_quiesced = False
def _collective_rpc(self, method: str, rpc_tags: list[str]) -> None:
"""Call TRT-LLM collective_rpc for KV cache sleep/wake."""
rpc = getattr(self._engine.llm, "_collective_rpc", None)
if rpc is None:
logger.warning(
"TRT-LLM does not expose _collective_rpc; skipping %s", method
)
return
try:
rpc(method, args=(rpc_tags,), kwargs={}, non_block=False)
except Exception:
if method != "wakeup":
raise
# Some TRT-LLM versions use "wake_up" instead of "wakeup"
rpc("wake_up", args=(rpc_tags,), kwargs={}, non_block=False)
@staticmethod
def _release_gms_weights() -> None:
"""Release GMS-managed weight memory."""
try:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
)
except ImportError:
return
manager = get_gms_client_memory_manager("weights")
if manager is None:
return
manager.unmap_all_vas()
manager.abort()
torch.cuda.synchronize()
torch.cuda.empty_cache()
@staticmethod
def _restore_gms_weights() -> None:
"""Restore GMS-managed weight memory."""
try:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
)
from gpu_memory_service.integrations.trtllm.model_loader import (
get_gms_lock_mode,
)
except ImportError:
return
manager = get_gms_client_memory_manager("weights")
if manager is None or not manager.is_unmapped:
return
manager.connect(get_gms_lock_mode())
manager.remap_all_vas()
class _Abortable(Protocol):
"""Structural type for objects that support abort(). Satisfied by both
......@@ -127,6 +219,7 @@ class RequestHandlerConfig:
metrics_collector: Optional["MetricsCollector"] = None
kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None
generate_endpoint: Optional[Any] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True
additional_metrics: Optional["AdditionalMetricsCollector"] = None
......@@ -160,10 +253,19 @@ class HandlerBase(BaseGenerativeHandler):
self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event
self.generate_endpoint = config.generate_endpoint
self.disable_request_abort = config.disable_request_abort
self.additional_metrics = config.additional_metrics
self.max_seq_len = config.max_seq_len
self.disagg_machine_id = config.disagg_machine_id
# Sleep/wake state
self._quiesce_lock = asyncio.Lock()
self._inflight_lock = asyncio.Lock()
self._inflight_requests = 0
self._no_inflight_requests = asyncio.Event()
self._no_inflight_requests.set()
self._quiesce_controller = TRTLLMEngineQuiesceController(config.engine)
self._reject_new_requests = False
def check_error(self, result: dict) -> bool:
"""
......@@ -176,6 +278,96 @@ class HandlerBase(BaseGenerativeHandler):
result["finish_reason"] == "stop" or result["finish_reason"] == "error"
)
# ------------------------------------------------------------------
# In-flight request tracking (used by sleep/wake)
# ------------------------------------------------------------------
async def _set_reject_new_requests(self, reject: bool) -> None:
async with self._inflight_lock:
self._reject_new_requests = reject
async def _mark_request_started(self) -> bool:
async with self._inflight_lock:
if self._reject_new_requests:
return False
self._inflight_requests += 1
self._no_inflight_requests.clear()
return True
async def _mark_request_finished(self) -> None:
async with self._inflight_lock:
if self._inflight_requests == 0:
return
self._inflight_requests -= 1
if self._inflight_requests == 0:
self._no_inflight_requests.set()
async def _wait_for_inflight_requests(self, timeout_s: float) -> None:
try:
await asyncio.wait_for(self._no_inflight_requests.wait(), timeout_s)
except asyncio.TimeoutError as exc:
async with self._inflight_lock:
inflight = self._inflight_requests
raise RuntimeError(
f"Timed out waiting for {inflight} in-flight request(s) to finish"
) from exc
# ------------------------------------------------------------------
# Sleep / wake public API (delegates to TRTLLMEngineQuiesceController)
# ------------------------------------------------------------------
async def release_memory_occupation(self, body: dict) -> dict:
"""Release GPU memory: unregister endpoint, drain requests, quiesce engine."""
body = body or {}
tags = body.get("tags")
async with self._quiesce_lock:
if self._quiesce_controller.is_quiesced:
return {"status": "ok", "message": "Memory already released"}
try:
await self._set_reject_new_requests(True)
if self.generate_endpoint is not None:
await self.generate_endpoint.unregister_endpoint_instance()
timeout_s = float(body.get("timeout_s", 30.0))
await self._wait_for_inflight_requests(timeout_s)
await self._quiesce_controller.quiesce(tags)
return {"status": "ok", "message": "Memory released"}
except Exception as exc:
logger.error("release_memory_occupation failed: %s", exc)
# Rollback: TRT-LLM has no pause_generation(), so we
# manually unregistered the endpoint and set reject flag
# above. Restore both on failure.
if self.generate_endpoint is not None:
await self.generate_endpoint.register_endpoint_instance()
await self._set_reject_new_requests(False)
return {"status": "error", "message": str(exc)}
async def resume_memory_occupation(self, body: dict) -> dict:
"""Restore GPU memory: resume engine, re-register endpoint."""
body = body or {}
tags = body.get("tags")
async with self._quiesce_lock:
if not self._quiesce_controller.is_quiesced:
return {"status": "ok", "message": "Memory already resumed"}
try:
await self._quiesce_controller.resume(tags)
if self.generate_endpoint is not None:
await self.generate_endpoint.register_endpoint_instance()
await self._set_reject_new_requests(False)
self._quiesce_controller.mark_resumed()
return {"status": "ok", "message": "Memory resumed"}
except Exception as exc:
logger.error("resume_memory_occupation failed: %s", exc)
return {"status": "error", "message": str(exc)}
@staticmethod
def _extract_logprobs(
output, num_output_tokens_so_far: int
......@@ -719,6 +911,31 @@ class HandlerBase(BaseGenerativeHandler):
context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None,
) -> AsyncGenerator[dict, None]:
"""Track in-flight count, reject during sleep, then delegate to implementation."""
started = await self._mark_request_started()
if not started:
yield {
"finish_reason": {
"error": "Worker is temporarily rejecting new requests"
},
"token_ids": [],
}
return
try:
async for chunk in self._generate_locally_impl(
request, context, embeddings, ep_disaggregated_params
):
yield chunk
finally:
await self._mark_request_finished()
async def _generate_locally_impl(
self,
request: dict,
context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None,
) -> AsyncGenerator[dict, None]:
"""
Generate responses based on the disaggregation mode in the request.
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for TRT-LLM sleep/wake handler logic.
Tests cover in-flight tracking, reject-flag, and sleep/wake delegation to
TRTLLMEngineQuiesceController without requiring a real GPU or TRT-LLM engine.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
import torch
if not torch.cuda.is_available():
pytest.skip(
"Skipping: CUDA not available (tensorrt_llm import requires GPU).",
allow_module_level=True,
)
from dynamo.trtllm.request_handlers.handler_base import HandlerBase
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
# ---------------------------------------------------------------------------
# Test fixture helpers
# ---------------------------------------------------------------------------
class _ConcreteHandler(HandlerBase):
async def generate(self, request, context):
yield {}
def _make_handler() -> _ConcreteHandler:
"""Create a HandlerBase subclass with mocked quiesce controller and endpoint."""
handler = _ConcreteHandler.__new__(_ConcreteHandler)
handler.generate_endpoint = SimpleNamespace(
unregister_endpoint_instance=AsyncMock(),
register_endpoint_instance=AsyncMock(),
)
handler._quiesce_lock = asyncio.Lock()
handler._inflight_lock = asyncio.Lock()
handler._inflight_requests = 0
handler._no_inflight_requests = asyncio.Event()
handler._no_inflight_requests.set()
handler._reject_new_requests = False
# Mock the quiesce controller that release/resume delegate to
handler._quiesce_controller = MagicMock()
handler._quiesce_controller.is_quiesced = False
handler._quiesce_controller.quiesce = AsyncMock(return_value=True)
handler._quiesce_controller.resume = AsyncMock(return_value=True)
handler._quiesce_controller.mark_resumed = MagicMock()
return handler
# ---------------------------------------------------------------------------
# In-flight tracking
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mark_request_started_respects_reject_flag():
handler = _make_handler()
await handler._set_reject_new_requests(True)
assert not await handler._mark_request_started()
assert handler._inflight_requests == 0
@pytest.mark.asyncio
async def test_mark_request_started_and_finished():
handler = _make_handler()
assert await handler._mark_request_started()
assert handler._inflight_requests == 1
assert not handler._no_inflight_requests.is_set()
await handler._mark_request_finished()
assert handler._inflight_requests == 0
assert handler._no_inflight_requests.is_set()
@pytest.mark.asyncio
async def test_mark_request_finished_is_idempotent():
handler = _make_handler()
# Extra call when count is already 0 must not underflow
await handler._mark_request_finished()
assert handler._inflight_requests == 0
# ---------------------------------------------------------------------------
# release_memory_occupation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_release_is_noop_when_already_quiesced():
handler = _make_handler()
handler._quiesce_controller.is_quiesced = True
result = await handler.release_memory_occupation({})
assert result["status"] == "ok"
assert "already released" in result["message"]
handler._quiesce_controller.quiesce.assert_not_called()
@pytest.mark.asyncio
async def test_release_returns_error_for_non_numeric_timeout():
handler = _make_handler()
result = await handler.release_memory_occupation({"timeout_s": "bad"})
assert result["status"] == "error"
@pytest.mark.asyncio
async def test_release_delegates_to_quiesce_controller():
handler = _make_handler()
result = await handler.release_memory_occupation({})
assert result["status"] == "ok"
handler._quiesce_controller.quiesce.assert_awaited_once_with(None)
@pytest.mark.asyncio
async def test_release_passes_tags_to_controller():
handler = _make_handler()
result = await handler.release_memory_occupation({"tags": ["weights"]})
assert result["status"] == "ok"
handler._quiesce_controller.quiesce.assert_awaited_once_with(["weights"])
@pytest.mark.asyncio
async def test_release_unregisters_endpoint_and_restores_on_error():
handler = _make_handler()
handler._quiesce_controller.quiesce = AsyncMock(
side_effect=RuntimeError("engine error")
)
result = await handler.release_memory_occupation({})
assert result["status"] == "error"
handler.generate_endpoint.unregister_endpoint_instance.assert_called_once()
handler.generate_endpoint.register_endpoint_instance.assert_called_once()
assert not handler._reject_new_requests
# ---------------------------------------------------------------------------
# resume_memory_occupation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resume_is_noop_when_not_quiesced():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert "already resumed" in result["message"]
handler._quiesce_controller.resume.assert_not_called()
@pytest.mark.asyncio
async def test_release_and_resume_round_trip():
handler = _make_handler()
release = await handler.release_memory_occupation({})
assert release["status"] == "ok"
# After release, controller reports quiesced
handler._quiesce_controller.is_quiesced = True
resume = await handler.resume_memory_occupation({})
assert resume["status"] == "ok"
handler._quiesce_controller.resume.assert_awaited_once()
handler._quiesce_controller.mark_resumed.assert_called_once()
assert not handler._reject_new_requests
handler.generate_endpoint.register_endpoint_instance.assert_called_once()
@pytest.mark.asyncio
async def test_resume_passes_tags_to_controller():
handler = _make_handler()
handler._quiesce_controller.is_quiesced = True
result = await handler.resume_memory_occupation({"tags": ["kv_cache"]})
assert result["status"] == "ok"
handler._quiesce_controller.resume.assert_awaited_once_with(["kv_cache"])
# ---------------------------------------------------------------------------
# generate_locally inflight guard
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_generate_locally_rejected_when_sleeping():
handler = _make_handler()
handler._reject_new_requests = True
chunks = []
ctx = MagicMock()
async for chunk in handler.generate_locally({"token_ids": []}, ctx):
chunks.append(chunk)
assert len(chunks) == 1
assert "error" in str(chunks[0].get("finish_reason", ""))
......@@ -22,7 +22,11 @@ from tensorrt_llm.llmapi import (
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.llmapi.llm_args import TOKENIZER_ALIASES, KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_args import (
TOKENIZER_ALIASES,
KvCacheConnectorConfig,
LoadFormat,
)
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.metrics import MetricsCollector
......@@ -127,6 +131,42 @@ def _warn_override_collisions(target: dict, source: dict, path: str = "") -> Non
)
def _parse_model_loader_extra_config(raw: object) -> dict[str, object]:
"""Parse --model-loader-extra-config into a dict. Accepts a dict or a JSON string."""
if raw is None or raw == "":
return {}
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError as exc:
raise ValueError(
f"Invalid JSON in --model-loader-extra-config: {exc}"
) from exc
if not isinstance(parsed, dict):
raise ValueError("--model-loader-extra-config must decode to a JSON object")
return parsed
raise ValueError(
"--model-loader-extra-config must be a JSON object string or a dict"
)
def _register_memory_routes(runtime, handler) -> None:
runtime.register_engine_route(
"release_memory_occupation",
handler.release_memory_occupation,
)
runtime.register_engine_route(
"resume_memory_occupation",
handler.resume_memory_occupation,
)
logging.info(
"Registered engine routes: "
"/engine/release_memory_occupation, /engine/resume_memory_occupation"
)
async def init_llm_worker(
runtime: DistributedRuntime,
config: Config,
......@@ -187,6 +227,40 @@ async def init_llm_worker(
)
kv_connector_config = build_kv_connector_config(config)
try:
model_loader_extra_config = _parse_model_loader_extra_config(
config.model_loader_extra_config
)
except ValueError as exc:
logging.error("%s", exc)
sys.exit(1)
if config.load_format == "gms":
try:
from gpu_memory_service.integrations.trtllm import setup_gms
except ImportError as exc:
raise RuntimeError(
"gpu-memory-service is required for --load-format gms. "
"Install or update the package."
) from exc
setup_gms(model_loader_extra_config)
logging.info(
"TRT-LLM GMS integration enabled (extra=%s)", model_loader_extra_config
)
# Resolve load_format for engine args. GMS patches are active regardless;
# fall back to "auto" if TRT-LLM doesn't recognise "gms" as a LoadFormat.
engine_load_format = config.load_format
if config.load_format == "gms":
try:
LoadFormat(config.load_format)
except (ValueError, KeyError):
logging.warning(
"TensorRT-LLM does not recognise load_format='gms'; "
"using 'auto' while GMS patches remain active."
)
engine_load_format = "auto"
arg_map = {
"model": model_path,
"scheduler_config": scheduler_config,
......@@ -209,6 +283,16 @@ async def init_llm_worker(
"kv_connector_config": kv_connector_config,
}
arg_map["load_format"] = engine_load_format
# Enable sleep_config when GMS manages weights — required for GMS
# unmap/remap. Conditional because SleepConfig contains unpicklable
# lambdas that break MPI-based multi-rank distribution.
if config.load_format == "gms":
from tensorrt_llm.llmapi.llm_args import SleepConfig
arg_map["sleep_config"] = SleepConfig()
# Add guided decoding backend if specified
if config.guided_decoding_backend is not None:
arg_map["guided_decoding_backend"] = config.guided_decoding_backend
......@@ -517,6 +601,7 @@ async def init_llm_worker(
disaggregation_mode=config.disaggregation_mode,
encode_client=encode_client,
multimodal_processor=multimodal_processor,
generate_endpoint=endpoint,
connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector,
......@@ -593,6 +678,8 @@ async def init_llm_worker(
) as publisher:
handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config)
if config.load_format == "gms":
_register_memory_routes(runtime, handler)
encoder_cache = getattr(handler, "_encoder_cache", None)
if encoder_cache is not None:
......@@ -602,7 +689,6 @@ async def init_llm_worker(
model_name=model_name_for_metrics,
component_name=config.component,
)
await endpoint.serve_endpoint(
handler.generate,
metrics_labels=metrics_labels,
......@@ -614,6 +700,8 @@ async def init_llm_worker(
consolidator_publisher.shutdown()
else:
handler = RequestHandlerFactory().get_request_handler(handler_config)
if config.load_format == "gms":
_register_memory_routes(runtime, handler)
await endpoint.serve_endpoint(
handler.generate, health_check_payload=health_check_payload
)
......@@ -98,7 +98,7 @@ trtllm:
runtime_image_tag: 26.02-cuda13.1-runtime-ubuntu24.04
nixl_ref: 0.10.1
enable_media_ffmpeg: "false"
enable_gpu_memory_service: "false"
enable_gpu_memory_service: "true"
enable_kvbm: "true"
python_version: "3.12"
index_url: https://pypi.nvidia.com/
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service integration for TensorRT-LLM.
Usage:
import json
from gpu_memory_service.integrations.trtllm import setup_gms
if config.load_format == "gms":
raw = config.model_loader_extra_config
extra = json.loads(raw) if isinstance(raw, str) else (raw or None)
setup_gms(extra)
"""
from __future__ import annotations
import logging
from typing import Any
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import (
get_gms_lock_mode as _resolve_lock_mode,
)
from gpu_memory_service.integrations.trtllm.model_loader import (
get_gms_lock_mode,
patch_model_loader,
set_gms_enabled,
set_gms_lock_mode,
)
logger = logging.getLogger(__name__)
__all__ = ["setup_gms", "get_gms_lock_mode"]
def setup_gms(model_loader_extra_config: dict[str, Any] | None = None) -> None:
"""Set up GMS integration for TensorRT-LLM. Call once before creating the engine."""
extra = model_loader_extra_config or {}
lock_mode = _resolve_lock_mode(extra)
set_gms_enabled(True)
set_gms_lock_mode(lock_mode)
patch_empty_cache()
patch_model_loader()
logger.info("[GMS] TensorRT-LLM integration enabled (mode=%s)", lock_mode)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TensorRT-LLM model loader patches for GPU Memory Service integration.
This module patches TensorRT-LLM's ModelLoader to load weights through GMS,
enabling VA-stable weight sharing and sleep/wake with memory release.
Two modes:
- RW (first loader): loads weights from disk, allocates via GMS pool, commits.
- RO (subsequent loaders): materializes model tensors from the committed layout.
"""
from __future__ import annotations
import copy
import logging
from typing import TYPE_CHECKING
import torch
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import finalize_gms_write
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
_model_loader_patched = False
_gms_enabled = False
_gms_lock_mode = RequestedLockType.RW_OR_RO
_last_imported_weights_bytes: int = 0
def set_gms_enabled(enabled: bool) -> None:
global _gms_enabled
_gms_enabled = enabled
def set_gms_lock_mode(mode: RequestedLockType) -> None:
global _gms_lock_mode
_gms_lock_mode = mode
def get_gms_lock_mode() -> RequestedLockType:
return _gms_lock_mode
def get_imported_weights_bytes() -> int:
"""Return total bytes of weights imported/published by the last model load."""
return _last_imported_weights_bytes
def patch_model_loader() -> None:
"""Patch TensorRT-LLM's ModelLoader to load weights through GMS.
Idempotent — safe to call multiple times.
"""
global _model_loader_patched
if _model_loader_patched:
return
import tensorrt_llm._torch.pyexecutor.model_loader as _trt_loader
_original_load = _trt_loader.ModelLoader.load
_original_get_rank_model_storage = _trt_loader.get_rank_model_storage
def patched_get_rank_model_storage(model) -> int:
imported = get_imported_weights_bytes()
if imported > 0:
return imported
return int(_original_get_rank_model_storage(model))
def patched_load(self, checkpoint_dir: str, checkpoint_loader):
if not _gms_enabled:
return _original_load(self, checkpoint_dir, checkpoint_loader)
return _gms_load(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
original_load=_original_load,
)
_trt_loader.get_rank_model_storage = patched_get_rank_model_storage
_trt_loader.ModelLoader.load = patched_load
_model_loader_patched = True
logger.info("[GMS] Patched TensorRT-LLM ModelLoader.load")
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _gms_load(self, checkpoint_dir: str, checkpoint_loader, original_load):
"""Route to RW (write) or RO (read) load path based on granted lock type."""
# Neutralize TRT-LLM's model_weights_memory_tag to prevent its VMM scope
# from overriding the GMS allocator. When sleep_config is set, TRT-LLM
# wraps allocation in virtual_memory_scope(model_weights) — a nested scope
# that would steal allocations away from GMS's gms_use_mem_pool.
saved_tag = getattr(self, "model_weights_memory_tag", None)
self.model_weights_memory_tag = None
device_index = torch.cuda.current_device()
gms_client = get_or_create_gms_client_memory_manager(
get_socket_path(device_index, "weights"),
device_index,
mode=_gms_lock_mode,
tag="weights",
)
try:
return _gms_load_inner(
self,
gms_client,
device_index,
checkpoint_dir,
checkpoint_loader,
original_load,
)
finally:
self.model_weights_memory_tag = saved_tag
def _gms_load_inner(
self, gms_client, device_index, checkpoint_dir, checkpoint_loader, original_load
):
if gms_client.granted_lock_type == GrantedLockType.RO:
return _load_ro(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
gms_client=gms_client,
device_index=device_index,
)
return _load_rw(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
gms_client=gms_client,
device_index=device_index,
original_load=original_load,
)
def _load_rw(
self, checkpoint_dir, checkpoint_loader, gms_client, device_index, original_load
):
"""RW path: load weights from disk into the GMS memory pool, then commit."""
global _last_imported_weights_bytes
target_device = torch.device("cuda", device_index)
with gms_use_mem_pool("weights", target_device):
model, moe_load_balancer = original_load(
self, checkpoint_dir, checkpoint_loader
)
_move_untracked_params(model, gms_client, target_device)
torch.cuda.empty_cache()
_last_imported_weights_bytes = finalize_gms_write(gms_client, model)
logger.info(
"[GMS] TRT-LLM RW: published %.2f GiB",
_last_imported_weights_bytes / (1 << 30),
)
return model, moe_load_balancer
def _load_ro(self, checkpoint_dir, checkpoint_loader, gms_client, device_index):
"""RO path: skip disk I/O, materialize tensors from the committed GMS layout."""
global _last_imported_weights_bytes
from tensorrt_llm._torch.models import AutoModelForCausalLM
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode, timing
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer,
maybe_create_moe_load_balancer,
)
config = self._load_and_validate_config(checkpoint_dir, checkpoint_loader)
with (
timing("Model init total"),
maybe_create_moe_load_balancer(config, self.mapping) as moe_load_balancer,
):
try:
with MetaInitMode():
model = AutoModelForCausalLM.from_config(copy.deepcopy(config))
except Exception as exc:
raise RuntimeError(
"GMS RO path requires successful MetaInitMode model construction"
) from exc
# Some models register cross-layer references like next_attn here.
if hasattr(model, "post_load_weights"):
model.post_load_weights()
materialize_module_from_gms(gms_client, model, device_index=device_index)
_last_imported_weights_bytes = int(gms_client.total_bytes)
logger.info(
"[GMS] TRT-LLM RO: imported %.2f GiB",
_last_imported_weights_bytes / (1 << 30),
)
for module in model.modules():
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
if isinstance(moe_load_balancer, MoeLoadBalancer):
moe_load_balancer.register_weight_slots_after_to_cuda()
moe_load_balancer.finalize_model()
torch.cuda.current_stream().synchronize()
return model, moe_load_balancer
def _ptr_in_gms(gms_client: "GMSClientMemoryManager", ptr: int) -> bool:
"""Return True if the given CUDA VA is within any active GMS mapping."""
for va, mapping in gms_client.mappings.items():
if va <= ptr < va + mapping.aligned_size:
return True
return False
def _move_untracked_params(
model: torch.nn.Module,
gms_client: "GMSClientMemoryManager",
target_device: torch.device,
) -> None:
"""Move CUDA parameters that were allocated outside the GMS pool into it.
TRT-LLM may allocate some parameters outside the pluggable-allocator scope.
This ensures all weight tensors end up tracked by GMS before we commit.
"""
from gpu_memory_service.client.torch.module import _iter_module_tensors
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
device_index = (
torch.cuda.current_device()
if target_device.index is None
else int(target_device.index)
)
seen: set[int] = set()
with torch.no_grad():
for _name, tensor, tensor_type in _iter_module_tensors(model):
if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
continue
storage_ptr = tensor.storage().data_ptr()
if storage_ptr in seen:
continue
seen.add(storage_ptr)
if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
continue
# Allocate a new mapping and copy the tensor into it
nbytes = _storage_nbytes(tensor)
base_va = gms_client.create_mapping(size=nbytes, tag="weights")
replacement = _tensor_from_pointer(
int(base_va),
list(tensor.shape),
list(tensor.stride()),
tensor.dtype,
device_index,
)
replacement.copy_(tensor)
tensor.data = replacement
def _storage_nbytes(tensor: torch.Tensor) -> int:
if tensor.numel() == 0:
return 0
element_size = int(tensor.element_size())
shape = list(tensor.shape)
stride = list(tensor.stride())
if not shape:
return element_size
max_offset = sum(
abs(int(s)) * (int(d) - 1)
for s, d in zip(stride, shape, strict=True)
if int(d) > 1
)
return int((max_offset + 1) * element_size)
......@@ -77,6 +77,7 @@ setup(
"gpu_memory_service.integrations",
"gpu_memory_service.integrations.common",
"gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.trtllm",
"gpu_memory_service.integrations.vllm",
],
package_dir={
......@@ -93,6 +94,7 @@ setup(
"gpu_memory_service.integrations": "integrations",
"gpu_memory_service.integrations.common": "integrations/common",
"gpu_memory_service.integrations.sglang": "integrations/sglang",
"gpu_memory_service.integrations.trtllm": "integrations/trtllm",
"gpu_memory_service.integrations.vllm": "integrations/vllm",
},
package_data={
......
......@@ -9,6 +9,7 @@ import json
import logging
import os
import sys
import time
from abc import ABC, abstractmethod
from contextlib import ExitStack
......@@ -34,6 +35,23 @@ def get_gpu_memory_used(device: int = 0) -> int:
pynvml.nvmlShutdown()
def wait_for_memory_drop(
baseline_bytes: int,
*,
timeout_s: float = 30.0,
poll_interval_s: float = 0.5,
) -> int:
"""Poll until GPU memory drops below *baseline_bytes*, then return current usage."""
deadline = time.monotonic() + timeout_s
current = get_gpu_memory_used()
while time.monotonic() < deadline:
if current < baseline_bytes:
return current
time.sleep(poll_interval_s)
current = get_gpu_memory_used()
return current
class GMSProcessManager:
"""Start the shared GMS daemons and frontend for one test scenario."""
......@@ -43,10 +61,12 @@ class GMSProcessManager:
engine_cls,
*,
read_only_weights: bool = False,
tags: tuple[str, ...] = ("weights", "kv_cache"),
):
self._request = request
self._engine_cls = engine_cls
self._read_only_weights = read_only_weights
self._tags = tags
self._stack: ExitStack | None = None
self.frontend_port: int | None = None
self.weights_gms = None
......@@ -57,8 +77,14 @@ class GMSProcessManager:
def __enter__(self):
stack = ExitStack()
try:
self.weights_gms = stack.enter_context(GMSServer(device=0, tag="weights"))
self.kv_cache_gms = stack.enter_context(GMSServer(device=0, tag="kv_cache"))
if "weights" in self._tags:
self.weights_gms = stack.enter_context(
GMSServer(device=0, tag="weights")
)
if "kv_cache" in self._tags:
self.kv_cache_gms = stack.enter_context(
GMSServer(device=0, tag="kv_cache")
)
frontend = stack.enter_context(
DynamoFrontendProcess(
self._request,
......@@ -306,6 +332,96 @@ class VLLMWithGMSProcess(GMSEngineProcess):
return {"level": 2}
class TRTLLMWithGMSProcess(GMSEngineProcess):
"""TensorRT-LLM engine with GMS weights + sleep/wake enabled."""
quiesce_route = "release_memory_occupation"
resume_route = "resume_memory_occupation"
# Override via environment variables for CI or custom setups.
TRTLLM_GMS_MODEL_NAME = os.environ.get(
"TRTLLM_GMS_MODEL_NAME", FAULT_TOLERANCE_MODEL_NAME
)
TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION = os.environ.get(
"TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION", "0.9"
)
TRTLLM_GMS_MAX_SEQ_LEN = os.environ.get("TRTLLM_GMS_MAX_SEQ_LEN", "256")
TRTLLM_GMS_MAX_NUM_TOKENS = os.environ.get("TRTLLM_GMS_MAX_NUM_TOKENS", "256")
TRTLLM_GMS_OVERRIDE_ENGINE_ARGS = os.environ.get(
"TRTLLM_GMS_OVERRIDE_ENGINE_ARGS", ""
)
def __init__(
self,
request,
frontend_port: int,
*,
engine_id: str,
read_only_weights: bool = False,
override_engine_args: str | None = None,
):
reserved_ports = allocate_ports(1)
self._override_engine_args = override_engine_args
try:
super().__init__(
request,
engine_id,
reserved_ports[0],
frontend_port,
reserved_ports,
read_only_weights=read_only_weights,
)
except Exception:
deallocate_ports(reserved_ports)
raise
def env_updates(self) -> dict[str, str]:
env = {
"CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
"TLLM_WORKER_USE_SINGLE_PROCESS": "1",
"MPI4PY_MPIABI": "openmpi",
"OMPI_MCA_coll_ucc_enable": "0",
}
venv = os.environ.get("VIRTUAL_ENV")
if venv:
venv_lib = os.path.join(venv, "lib")
existing = os.environ.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = f"{venv_lib}:{existing}" if existing else venv_lib
return env
def command(self) -> list[str]:
command = [
sys.executable,
"-m",
"dynamo.trtllm",
"--model",
self.TRTLLM_GMS_MODEL_NAME,
"--gpus-per-node",
"1",
"--load-format",
"gms",
"--free-gpu-memory-fraction",
self.TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION,
"--max-seq-len",
self.TRTLLM_GMS_MAX_SEQ_LEN,
"--max-num-tokens",
self.TRTLLM_GMS_MAX_NUM_TOKENS,
]
effective_override = self._override_engine_args
if effective_override is None:
effective_override = self.TRTLLM_GMS_OVERRIDE_ENGINE_ARGS
if effective_override:
command.extend(["--override-engine-args", effective_override])
extra_config = self.model_loader_extra_config()
if extra_config is not None:
command.extend(["--model-loader-extra-config", extra_config])
return command
def quiesce_payload(self) -> dict:
return {}
class SGLangWithGMSProcess(GMSEngineProcess):
quiesce_route = "release_memory_occupation"
resume_route = "resume_memory_occupation"
......
......@@ -226,3 +226,31 @@ def assert_kv_history(
]
assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts])
def wait_for_weights_state(
weights_gms,
expected_state,
*,
min_ro_sessions: int = 0,
expected_hash: str | None = None,
timeout: float = 30.0,
):
"""Poll until the weights GMS daemon reaches *expected_state*."""
deadline = time.monotonic() + timeout
while True:
ws = weights_gms.get_runtime_state()
if (
ws.state == expected_state
and ws.allocation_count > 0
and ws.memory_layout_hash
and ws.ro_session_count >= min_ro_sessions
and (expected_hash is None or ws.memory_layout_hash == expected_hash)
):
return ws
if time.monotonic() > deadline:
raise TimeoutError(
f"Weights: state={ws.state} (want {expected_state}), "
f"allocs={ws.allocation_count}, hash={ws.memory_layout_hash}"
)
time.sleep(0.1)
......@@ -6,10 +6,12 @@ from __future__ import annotations
import logging
import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import (
GMSProcessManager,
SGLangWithGMSProcess,
TRTLLMWithGMSProcess,
VLLMWithGMSProcess,
get_gpu_memory_used,
)
......@@ -20,6 +22,7 @@ from tests.gpu_memory_service.flow_assertions import (
assert_weights_published_once,
quiesce_engine,
wait_for_resumed_layout,
wait_for_weights_state,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
......@@ -137,3 +140,110 @@ def test_gms_basic_quiesce_resume_sglang(
predownload_models,
):
_run_quiesce_resume_test(request, SGLangWithGMSProcess)
# ---------------------------------------------------------------------------
# TRT-LLM standalone tests (weights-only GMS topology, no KV cache GMS)
# ---------------------------------------------------------------------------
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_basic_quiesce_resume_trtllm(
request,
runtime_services_dynamic_ports,
predownload_models,
):
"""Weights-only quiesce/resume for TRT-LLM (no KV cache GMS)."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
engine = manager.start_engine("engine")
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Initial inference failed",
success_message="Initial inference OK",
)
ws = wait_for_weights_state(weights_gms, ServerState.RO, timeout=60.0)
weights_hash = ws.memory_layout_hash
mem_before = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
mem_after = get_gpu_memory_used()
released = mem_before - mem_after
logger.info(
"TRT-LLM quiesce: %.2f -> %.2f GiB (freed %.0f MB)",
mem_before / (1 << 30),
mem_after / (1 << 30),
released / (1 << 20),
)
assert released > 0
wait_for_weights_state(
weights_gms, ServerState.COMMITTED, expected_hash=weights_hash
)
assert_weights_published_once(weights_gms.get_event_history().events)
assert engine.resume()["status"] == "ok"
mem_resumed = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"TRT-LLM resume",
mem_after,
mem_resumed,
released,
min_fraction=0.6,
)
wait_for_weights_state(weights_gms, ServerState.RO, expected_hash=weights_hash)
assert_weights_published_once(weights_gms.get_event_history().events)
assert_completion_ok(
frontend_port,
"Goodbye",
failure_message="Post-resume inference failed",
success_message="Post-resume inference OK",
)
logger.info("Memory freed: %.0f MB", released / (1 << 20))
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_read_only_import_trtllm(
request,
runtime_services_dynamic_ports,
predownload_models,
):
"""A second TRT-LLM process with read_only_weights=True imports weights
from the committed layout published by the first, sharing GPU memory."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
manager.start_engine("rw-engine")
ws = wait_for_weights_state(weights_gms, ServerState.RO, timeout=60.0)
weights_hash = ws.memory_layout_hash
manager.start_engine("ro-engine", read_only_weights=True)
wait_for_weights_state(
weights_gms,
ServerState.RO,
min_ro_sessions=1,
expected_hash=weights_hash,
timeout=60.0,
)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="RW+RO inference failed",
success_message="RW+RO inference OK",
)
......@@ -15,6 +15,7 @@ from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import (
GMSProcessManager,
SGLangWithGMSProcess,
TRTLLMWithGMSProcess,
VLLMWithGMSProcess,
get_gpu_memory_used,
)
......@@ -26,6 +27,7 @@ from tests.gpu_memory_service.flow_assertions import (
quiesce_engine,
wait_for_active_layout,
wait_for_resumed_layout,
wait_for_weights_state,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
......@@ -321,3 +323,133 @@ def test_gms_shadow_engine_failover_sglang(
request, runtime_services_dynamic_ports, predownload_models
):
_run_shadow_failover_test(request, SGLangWithGMSProcess)
# ---------------------------------------------------------------------------
# TRT-LLM standalone failover test (weights-only GMS, no KV cache GMS)
# ---------------------------------------------------------------------------
def _trtllm_quiesce(
weights_gms,
engine,
*,
label: str,
expected_hash: str | None = None,
):
"""Quiesce a weights-only TRT-LLM engine and return state tuple."""
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=expected_hash,
timeout=60.0,
)
mem_before = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
mem_after = get_gpu_memory_used()
released = mem_before - mem_after
logger.info(
"%s: %.2f -> %.2f GiB (freed %.0f MB)",
label,
mem_before / (1 << 30),
mem_after / (1 << 30),
released / (1 << 20),
)
assert released > 0
ws = wait_for_weights_state(weights_gms, ServerState.COMMITTED)
return ws, released, mem_after
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_trtllm(
request, runtime_services_dynamic_ports, predownload_models
):
"""Weights-only shadow failover for TRT-LLM (no KV cache GMS)."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
# Shadow A publishes weights, then quiesces.
shadow_a = manager.start_engine("shadow-a")
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow A inference failed",
success_message="Shadow A inference OK",
)
ws_a, released_a, _ = _trtllm_quiesce(
weights_gms, shadow_a, label="Shadow A quiesce"
)
weights_hash = ws_a.memory_layout_hash
# Shadow B starts RO, then quiesces.
shadow_b = manager.start_engine("shadow-b", read_only_weights=True)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow B inference failed",
success_message="Shadow B inference OK",
)
_, _, mem_after_b = _trtllm_quiesce(
weights_gms,
shadow_b,
label="Shadow B quiesce",
expected_hash=weights_hash,
)
assert_weights_published_once(weights_gms.get_event_history().events)
# Primary starts RO.
primary = manager.start_engine("primary", read_only_weights=True)
assert_completion_ok(
frontend_port,
"Primary test",
failure_message="Primary inference failed",
success_message="Primary inference OK",
)
primary_mem = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Primary active",
mem_after_b,
primary_mem,
released_a,
min_fraction=0.6,
)
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=weights_hash,
min_ro_sessions=1,
)
# Kill primary, resume shadow A immediately (no KV blocking).
_kill_process_group(primary)
resume_result = shadow_a.resume(timeout=180)
assert resume_result["status"] == "ok"
shadow_mem = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Shadow A resume",
mem_after_b,
shadow_mem,
released_a,
min_fraction=0.6,
)
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=weights_hash,
min_ro_sessions=1,
)
assert_weights_published_once(weights_gms.get_event_history().events)
assert_completion_ok(
frontend_port,
"Post failover",
failure_message="Shadow after failover failed",
success_message="Shadow after failover OK",
retry_timeout=30.0,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment