Unverified Commit d96a2cf1 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: add TRT-LLM sleep/wake integration with GMS (#7575)


Co-authored-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent 119149f2
...@@ -160,6 +160,23 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -160,6 +160,23 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=True, default=True,
help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.", help="Disable calling abort() on the TRT-LLM engine when a request is cancelled.",
) )
add_argument(
g,
flag_name="--load-format",
env_var="DYN_TRTLLM_LOAD_FORMAT",
default="auto",
help="Model weight loading format passed to TensorRT-LLM (e.g. 'auto', 'gms').",
)
add_argument(
g,
flag_name="--model-loader-extra-config",
env_var="DYN_TRTLLM_MODEL_LOADER_EXTRA_CONFIG",
default="",
help=(
"JSON object passed as extra config to the model loader "
"(e.g. '{\"gms_read_only\": true}')."
),
)
add_argument( add_argument(
g, g,
flag_name="--disaggregation-mode", flag_name="--disaggregation-mode",
...@@ -453,6 +470,8 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -453,6 +470,8 @@ class DynamoTrtllmConfig(ConfigBase):
override_engine_args: str override_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disable_request_abort: bool disable_request_abort: bool
load_format: str
model_loader_extra_config: str
guided_decoding_backend: Optional[str] = None guided_decoding_backend: Optional[str] = None
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
......
...@@ -58,6 +58,98 @@ if TYPE_CHECKING: ...@@ -58,6 +58,98 @@ if TYPE_CHECKING:
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__)
class TRTLLMEngineQuiesceController:
"""Adapts TRT-LLM sleep/wake to the standard quiesce controller interface.
Two memory domains: KV cache via TRT-LLM collective_rpc, weights via GMS.
"""
def __init__(self, engine: TensorRTLLMEngine):
self._engine = engine
self._is_quiesced = False
@property
def is_quiesced(self) -> bool:
return self._is_quiesced
async def quiesce(self, tags: list[str] | None = None) -> bool:
if self._is_quiesced:
return False
tags = tags or ["kv_cache", "weights"]
if "kv_cache" in tags:
self._collective_rpc("sleep", ["kv_cache"])
if "weights" in tags:
self._release_gms_weights()
self._is_quiesced = True
return True
async def resume(self, tags: list[str] | None = None) -> bool:
if not self._is_quiesced:
return False
tags = tags or ["kv_cache", "weights"]
if "weights" in tags:
self._restore_gms_weights()
if "kv_cache" in tags:
self._collective_rpc("wakeup", ["kv_cache"])
return True
def mark_resumed(self) -> None:
self._is_quiesced = False
def _collective_rpc(self, method: str, rpc_tags: list[str]) -> None:
"""Call TRT-LLM collective_rpc for KV cache sleep/wake."""
rpc = getattr(self._engine.llm, "_collective_rpc", None)
if rpc is None:
logger.warning(
"TRT-LLM does not expose _collective_rpc; skipping %s", method
)
return
try:
rpc(method, args=(rpc_tags,), kwargs={}, non_block=False)
except Exception:
if method != "wakeup":
raise
# Some TRT-LLM versions use "wake_up" instead of "wakeup"
rpc("wake_up", args=(rpc_tags,), kwargs={}, non_block=False)
@staticmethod
def _release_gms_weights() -> None:
"""Release GMS-managed weight memory."""
try:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
)
except ImportError:
return
manager = get_gms_client_memory_manager("weights")
if manager is None:
return
manager.unmap_all_vas()
manager.abort()
torch.cuda.synchronize()
torch.cuda.empty_cache()
@staticmethod
def _restore_gms_weights() -> None:
"""Restore GMS-managed weight memory."""
try:
from gpu_memory_service.client.torch.allocator import (
get_gms_client_memory_manager,
)
from gpu_memory_service.integrations.trtllm.model_loader import (
get_gms_lock_mode,
)
except ImportError:
return
manager = get_gms_client_memory_manager("weights")
if manager is None or not manager.is_unmapped:
return
manager.connect(get_gms_lock_mode())
manager.remap_all_vas()
class _Abortable(Protocol): class _Abortable(Protocol):
"""Structural type for objects that support abort(). Satisfied by both """Structural type for objects that support abort(). Satisfied by both
...@@ -127,6 +219,7 @@ class RequestHandlerConfig: ...@@ -127,6 +219,7 @@ class RequestHandlerConfig:
metrics_collector: Optional["MetricsCollector"] = None metrics_collector: Optional["MetricsCollector"] = None
kv_block_size: int = 32 kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None shutdown_event: Optional[asyncio.Event] = None
generate_endpoint: Optional[Any] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
disable_request_abort: bool = True disable_request_abort: bool = True
additional_metrics: Optional["AdditionalMetricsCollector"] = None additional_metrics: Optional["AdditionalMetricsCollector"] = None
...@@ -160,10 +253,19 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -160,10 +253,19 @@ class HandlerBase(BaseGenerativeHandler):
self.runtime = config.runtime self.runtime = config.runtime
self.kv_block_size: int = config.kv_block_size self.kv_block_size: int = config.kv_block_size
self.shutdown_event = config.shutdown_event self.shutdown_event = config.shutdown_event
self.generate_endpoint = config.generate_endpoint
self.disable_request_abort = config.disable_request_abort self.disable_request_abort = config.disable_request_abort
self.additional_metrics = config.additional_metrics self.additional_metrics = config.additional_metrics
self.max_seq_len = config.max_seq_len self.max_seq_len = config.max_seq_len
self.disagg_machine_id = config.disagg_machine_id self.disagg_machine_id = config.disagg_machine_id
# Sleep/wake state
self._quiesce_lock = asyncio.Lock()
self._inflight_lock = asyncio.Lock()
self._inflight_requests = 0
self._no_inflight_requests = asyncio.Event()
self._no_inflight_requests.set()
self._quiesce_controller = TRTLLMEngineQuiesceController(config.engine)
self._reject_new_requests = False
def check_error(self, result: dict) -> bool: def check_error(self, result: dict) -> bool:
""" """
...@@ -176,6 +278,96 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -176,6 +278,96 @@ class HandlerBase(BaseGenerativeHandler):
result["finish_reason"] == "stop" or result["finish_reason"] == "error" result["finish_reason"] == "stop" or result["finish_reason"] == "error"
) )
# ------------------------------------------------------------------
# In-flight request tracking (used by sleep/wake)
# ------------------------------------------------------------------
async def _set_reject_new_requests(self, reject: bool) -> None:
async with self._inflight_lock:
self._reject_new_requests = reject
async def _mark_request_started(self) -> bool:
async with self._inflight_lock:
if self._reject_new_requests:
return False
self._inflight_requests += 1
self._no_inflight_requests.clear()
return True
async def _mark_request_finished(self) -> None:
async with self._inflight_lock:
if self._inflight_requests == 0:
return
self._inflight_requests -= 1
if self._inflight_requests == 0:
self._no_inflight_requests.set()
async def _wait_for_inflight_requests(self, timeout_s: float) -> None:
try:
await asyncio.wait_for(self._no_inflight_requests.wait(), timeout_s)
except asyncio.TimeoutError as exc:
async with self._inflight_lock:
inflight = self._inflight_requests
raise RuntimeError(
f"Timed out waiting for {inflight} in-flight request(s) to finish"
) from exc
# ------------------------------------------------------------------
# Sleep / wake public API (delegates to TRTLLMEngineQuiesceController)
# ------------------------------------------------------------------
async def release_memory_occupation(self, body: dict) -> dict:
"""Release GPU memory: unregister endpoint, drain requests, quiesce engine."""
body = body or {}
tags = body.get("tags")
async with self._quiesce_lock:
if self._quiesce_controller.is_quiesced:
return {"status": "ok", "message": "Memory already released"}
try:
await self._set_reject_new_requests(True)
if self.generate_endpoint is not None:
await self.generate_endpoint.unregister_endpoint_instance()
timeout_s = float(body.get("timeout_s", 30.0))
await self._wait_for_inflight_requests(timeout_s)
await self._quiesce_controller.quiesce(tags)
return {"status": "ok", "message": "Memory released"}
except Exception as exc:
logger.error("release_memory_occupation failed: %s", exc)
# Rollback: TRT-LLM has no pause_generation(), so we
# manually unregistered the endpoint and set reject flag
# above. Restore both on failure.
if self.generate_endpoint is not None:
await self.generate_endpoint.register_endpoint_instance()
await self._set_reject_new_requests(False)
return {"status": "error", "message": str(exc)}
async def resume_memory_occupation(self, body: dict) -> dict:
"""Restore GPU memory: resume engine, re-register endpoint."""
body = body or {}
tags = body.get("tags")
async with self._quiesce_lock:
if not self._quiesce_controller.is_quiesced:
return {"status": "ok", "message": "Memory already resumed"}
try:
await self._quiesce_controller.resume(tags)
if self.generate_endpoint is not None:
await self.generate_endpoint.register_endpoint_instance()
await self._set_reject_new_requests(False)
self._quiesce_controller.mark_resumed()
return {"status": "ok", "message": "Memory resumed"}
except Exception as exc:
logger.error("resume_memory_occupation failed: %s", exc)
return {"status": "error", "message": str(exc)}
@staticmethod @staticmethod
def _extract_logprobs( def _extract_logprobs(
output, num_output_tokens_so_far: int output, num_output_tokens_so_far: int
...@@ -719,6 +911,31 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -719,6 +911,31 @@ class HandlerBase(BaseGenerativeHandler):
context: Context, context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None, embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None, ep_disaggregated_params: Optional[DisaggregatedParams] = None,
) -> AsyncGenerator[dict, None]:
"""Track in-flight count, reject during sleep, then delegate to implementation."""
started = await self._mark_request_started()
if not started:
yield {
"finish_reason": {
"error": "Worker is temporarily rejecting new requests"
},
"token_ids": [],
}
return
try:
async for chunk in self._generate_locally_impl(
request, context, embeddings, ep_disaggregated_params
):
yield chunk
finally:
await self._mark_request_finished()
async def _generate_locally_impl(
self,
request: dict,
context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None,
) -> AsyncGenerator[dict, None]: ) -> AsyncGenerator[dict, None]:
""" """
Generate responses based on the disaggregation mode in the request. Generate responses based on the disaggregation mode in the request.
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for TRT-LLM sleep/wake handler logic.
Tests cover in-flight tracking, reject-flag, and sleep/wake delegation to
TRTLLMEngineQuiesceController without requiring a real GPU or TRT-LLM engine.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
import torch
if not torch.cuda.is_available():
pytest.skip(
"Skipping: CUDA not available (tensorrt_llm import requires GPU).",
allow_module_level=True,
)
from dynamo.trtllm.request_handlers.handler_base import HandlerBase
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
# ---------------------------------------------------------------------------
# Test fixture helpers
# ---------------------------------------------------------------------------
class _ConcreteHandler(HandlerBase):
async def generate(self, request, context):
yield {}
def _make_handler() -> _ConcreteHandler:
"""Create a HandlerBase subclass with mocked quiesce controller and endpoint."""
handler = _ConcreteHandler.__new__(_ConcreteHandler)
handler.generate_endpoint = SimpleNamespace(
unregister_endpoint_instance=AsyncMock(),
register_endpoint_instance=AsyncMock(),
)
handler._quiesce_lock = asyncio.Lock()
handler._inflight_lock = asyncio.Lock()
handler._inflight_requests = 0
handler._no_inflight_requests = asyncio.Event()
handler._no_inflight_requests.set()
handler._reject_new_requests = False
# Mock the quiesce controller that release/resume delegate to
handler._quiesce_controller = MagicMock()
handler._quiesce_controller.is_quiesced = False
handler._quiesce_controller.quiesce = AsyncMock(return_value=True)
handler._quiesce_controller.resume = AsyncMock(return_value=True)
handler._quiesce_controller.mark_resumed = MagicMock()
return handler
# ---------------------------------------------------------------------------
# In-flight tracking
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mark_request_started_respects_reject_flag():
handler = _make_handler()
await handler._set_reject_new_requests(True)
assert not await handler._mark_request_started()
assert handler._inflight_requests == 0
@pytest.mark.asyncio
async def test_mark_request_started_and_finished():
handler = _make_handler()
assert await handler._mark_request_started()
assert handler._inflight_requests == 1
assert not handler._no_inflight_requests.is_set()
await handler._mark_request_finished()
assert handler._inflight_requests == 0
assert handler._no_inflight_requests.is_set()
@pytest.mark.asyncio
async def test_mark_request_finished_is_idempotent():
handler = _make_handler()
# Extra call when count is already 0 must not underflow
await handler._mark_request_finished()
assert handler._inflight_requests == 0
# ---------------------------------------------------------------------------
# release_memory_occupation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_release_is_noop_when_already_quiesced():
handler = _make_handler()
handler._quiesce_controller.is_quiesced = True
result = await handler.release_memory_occupation({})
assert result["status"] == "ok"
assert "already released" in result["message"]
handler._quiesce_controller.quiesce.assert_not_called()
@pytest.mark.asyncio
async def test_release_returns_error_for_non_numeric_timeout():
handler = _make_handler()
result = await handler.release_memory_occupation({"timeout_s": "bad"})
assert result["status"] == "error"
@pytest.mark.asyncio
async def test_release_delegates_to_quiesce_controller():
handler = _make_handler()
result = await handler.release_memory_occupation({})
assert result["status"] == "ok"
handler._quiesce_controller.quiesce.assert_awaited_once_with(None)
@pytest.mark.asyncio
async def test_release_passes_tags_to_controller():
handler = _make_handler()
result = await handler.release_memory_occupation({"tags": ["weights"]})
assert result["status"] == "ok"
handler._quiesce_controller.quiesce.assert_awaited_once_with(["weights"])
@pytest.mark.asyncio
async def test_release_unregisters_endpoint_and_restores_on_error():
handler = _make_handler()
handler._quiesce_controller.quiesce = AsyncMock(
side_effect=RuntimeError("engine error")
)
result = await handler.release_memory_occupation({})
assert result["status"] == "error"
handler.generate_endpoint.unregister_endpoint_instance.assert_called_once()
handler.generate_endpoint.register_endpoint_instance.assert_called_once()
assert not handler._reject_new_requests
# ---------------------------------------------------------------------------
# resume_memory_occupation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resume_is_noop_when_not_quiesced():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert "already resumed" in result["message"]
handler._quiesce_controller.resume.assert_not_called()
@pytest.mark.asyncio
async def test_release_and_resume_round_trip():
handler = _make_handler()
release = await handler.release_memory_occupation({})
assert release["status"] == "ok"
# After release, controller reports quiesced
handler._quiesce_controller.is_quiesced = True
resume = await handler.resume_memory_occupation({})
assert resume["status"] == "ok"
handler._quiesce_controller.resume.assert_awaited_once()
handler._quiesce_controller.mark_resumed.assert_called_once()
assert not handler._reject_new_requests
handler.generate_endpoint.register_endpoint_instance.assert_called_once()
@pytest.mark.asyncio
async def test_resume_passes_tags_to_controller():
handler = _make_handler()
handler._quiesce_controller.is_quiesced = True
result = await handler.resume_memory_occupation({"tags": ["kv_cache"]})
assert result["status"] == "ok"
handler._quiesce_controller.resume.assert_awaited_once_with(["kv_cache"])
# ---------------------------------------------------------------------------
# generate_locally inflight guard
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_generate_locally_rejected_when_sleeping():
handler = _make_handler()
handler._reject_new_requests = True
chunks = []
ctx = MagicMock()
async for chunk in handler.generate_locally({"token_ids": []}, ctx):
chunks.append(chunk)
assert len(chunks) == 1
assert "error" in str(chunks[0].get("finish_reason", ""))
...@@ -22,7 +22,11 @@ from tensorrt_llm.llmapi import ( ...@@ -22,7 +22,11 @@ from tensorrt_llm.llmapi import (
SchedulerConfig, SchedulerConfig,
) )
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.llmapi.llm_args import TOKENIZER_ALIASES, KvCacheConnectorConfig from tensorrt_llm.llmapi.llm_args import (
TOKENIZER_ALIASES,
KvCacheConnectorConfig,
LoadFormat,
)
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.metrics import MetricsCollector from tensorrt_llm.metrics import MetricsCollector
...@@ -127,6 +131,42 @@ def _warn_override_collisions(target: dict, source: dict, path: str = "") -> Non ...@@ -127,6 +131,42 @@ def _warn_override_collisions(target: dict, source: dict, path: str = "") -> Non
) )
def _parse_model_loader_extra_config(raw: object) -> dict[str, object]:
"""Parse --model-loader-extra-config into a dict. Accepts a dict or a JSON string."""
if raw is None or raw == "":
return {}
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError as exc:
raise ValueError(
f"Invalid JSON in --model-loader-extra-config: {exc}"
) from exc
if not isinstance(parsed, dict):
raise ValueError("--model-loader-extra-config must decode to a JSON object")
return parsed
raise ValueError(
"--model-loader-extra-config must be a JSON object string or a dict"
)
def _register_memory_routes(runtime, handler) -> None:
runtime.register_engine_route(
"release_memory_occupation",
handler.release_memory_occupation,
)
runtime.register_engine_route(
"resume_memory_occupation",
handler.resume_memory_occupation,
)
logging.info(
"Registered engine routes: "
"/engine/release_memory_occupation, /engine/resume_memory_occupation"
)
async def init_llm_worker( async def init_llm_worker(
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
...@@ -187,6 +227,40 @@ async def init_llm_worker( ...@@ -187,6 +227,40 @@ async def init_llm_worker(
) )
kv_connector_config = build_kv_connector_config(config) kv_connector_config = build_kv_connector_config(config)
try:
model_loader_extra_config = _parse_model_loader_extra_config(
config.model_loader_extra_config
)
except ValueError as exc:
logging.error("%s", exc)
sys.exit(1)
if config.load_format == "gms":
try:
from gpu_memory_service.integrations.trtllm import setup_gms
except ImportError as exc:
raise RuntimeError(
"gpu-memory-service is required for --load-format gms. "
"Install or update the package."
) from exc
setup_gms(model_loader_extra_config)
logging.info(
"TRT-LLM GMS integration enabled (extra=%s)", model_loader_extra_config
)
# Resolve load_format for engine args. GMS patches are active regardless;
# fall back to "auto" if TRT-LLM doesn't recognise "gms" as a LoadFormat.
engine_load_format = config.load_format
if config.load_format == "gms":
try:
LoadFormat(config.load_format)
except (ValueError, KeyError):
logging.warning(
"TensorRT-LLM does not recognise load_format='gms'; "
"using 'auto' while GMS patches remain active."
)
engine_load_format = "auto"
arg_map = { arg_map = {
"model": model_path, "model": model_path,
"scheduler_config": scheduler_config, "scheduler_config": scheduler_config,
...@@ -209,6 +283,16 @@ async def init_llm_worker( ...@@ -209,6 +283,16 @@ async def init_llm_worker(
"kv_connector_config": kv_connector_config, "kv_connector_config": kv_connector_config,
} }
arg_map["load_format"] = engine_load_format
# Enable sleep_config when GMS manages weights — required for GMS
# unmap/remap. Conditional because SleepConfig contains unpicklable
# lambdas that break MPI-based multi-rank distribution.
if config.load_format == "gms":
from tensorrt_llm.llmapi.llm_args import SleepConfig
arg_map["sleep_config"] = SleepConfig()
# Add guided decoding backend if specified # Add guided decoding backend if specified
if config.guided_decoding_backend is not None: if config.guided_decoding_backend is not None:
arg_map["guided_decoding_backend"] = config.guided_decoding_backend arg_map["guided_decoding_backend"] = config.guided_decoding_backend
...@@ -517,6 +601,7 @@ async def init_llm_worker( ...@@ -517,6 +601,7 @@ async def init_llm_worker(
disaggregation_mode=config.disaggregation_mode, disaggregation_mode=config.disaggregation_mode,
encode_client=encode_client, encode_client=encode_client,
multimodal_processor=multimodal_processor, multimodal_processor=multimodal_processor,
generate_endpoint=endpoint,
connector=connector, connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown runtime=runtime, # Pass runtime for graceful shutdown
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
...@@ -593,6 +678,8 @@ async def init_llm_worker( ...@@ -593,6 +678,8 @@ async def init_llm_worker(
) as publisher: ) as publisher:
handler_config.publisher = publisher handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config) handler = RequestHandlerFactory().get_request_handler(handler_config)
if config.load_format == "gms":
_register_memory_routes(runtime, handler)
encoder_cache = getattr(handler, "_encoder_cache", None) encoder_cache = getattr(handler, "_encoder_cache", None)
if encoder_cache is not None: if encoder_cache is not None:
...@@ -602,7 +689,6 @@ async def init_llm_worker( ...@@ -602,7 +689,6 @@ async def init_llm_worker(
model_name=model_name_for_metrics, model_name=model_name_for_metrics,
component_name=config.component, component_name=config.component,
) )
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
handler.generate, handler.generate,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
...@@ -614,6 +700,8 @@ async def init_llm_worker( ...@@ -614,6 +700,8 @@ async def init_llm_worker(
consolidator_publisher.shutdown() consolidator_publisher.shutdown()
else: else:
handler = RequestHandlerFactory().get_request_handler(handler_config) handler = RequestHandlerFactory().get_request_handler(handler_config)
if config.load_format == "gms":
_register_memory_routes(runtime, handler)
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
handler.generate, health_check_payload=health_check_payload handler.generate, health_check_payload=health_check_payload
) )
...@@ -98,7 +98,7 @@ trtllm: ...@@ -98,7 +98,7 @@ trtllm:
runtime_image_tag: 26.02-cuda13.1-runtime-ubuntu24.04 runtime_image_tag: 26.02-cuda13.1-runtime-ubuntu24.04
nixl_ref: 0.10.1 nixl_ref: 0.10.1
enable_media_ffmpeg: "false" enable_media_ffmpeg: "false"
enable_gpu_memory_service: "false" enable_gpu_memory_service: "true"
enable_kvbm: "true" enable_kvbm: "true"
python_version: "3.12" python_version: "3.12"
index_url: https://pypi.nvidia.com/ index_url: https://pypi.nvidia.com/
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GPU Memory Service integration for TensorRT-LLM.
Usage:
import json
from gpu_memory_service.integrations.trtllm import setup_gms
if config.load_format == "gms":
raw = config.model_loader_extra_config
extra = json.loads(raw) if isinstance(raw, str) else (raw or None)
setup_gms(extra)
"""
from __future__ import annotations
import logging
from typing import Any
from gpu_memory_service.integrations.common import patch_empty_cache
from gpu_memory_service.integrations.common.utils import (
get_gms_lock_mode as _resolve_lock_mode,
)
from gpu_memory_service.integrations.trtllm.model_loader import (
get_gms_lock_mode,
patch_model_loader,
set_gms_enabled,
set_gms_lock_mode,
)
logger = logging.getLogger(__name__)
__all__ = ["setup_gms", "get_gms_lock_mode"]
def setup_gms(model_loader_extra_config: dict[str, Any] | None = None) -> None:
"""Set up GMS integration for TensorRT-LLM. Call once before creating the engine."""
extra = model_loader_extra_config or {}
lock_mode = _resolve_lock_mode(extra)
set_gms_enabled(True)
set_gms_lock_mode(lock_mode)
patch_empty_cache()
patch_model_loader()
logger.info("[GMS] TensorRT-LLM integration enabled (mode=%s)", lock_mode)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TensorRT-LLM model loader patches for GPU Memory Service integration.
This module patches TensorRT-LLM's ModelLoader to load weights through GMS,
enabling VA-stable weight sharing and sleep/wake with memory release.
Two modes:
- RW (first loader): loads weights from disk, allocates via GMS pool, commits.
- RO (subsequent loaders): materializes model tensors from the committed layout.
"""
from __future__ import annotations
import copy
import logging
from typing import TYPE_CHECKING
import torch
from gpu_memory_service.client.torch.allocator import (
get_or_create_gms_client_memory_manager,
gms_use_mem_pool,
)
from gpu_memory_service.client.torch.module import materialize_module_from_gms
from gpu_memory_service.common.locks import GrantedLockType, RequestedLockType
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.integrations.common.utils import finalize_gms_write
if TYPE_CHECKING:
from gpu_memory_service.client.memory_manager import GMSClientMemoryManager
logger = logging.getLogger(__name__)
_model_loader_patched = False
_gms_enabled = False
_gms_lock_mode = RequestedLockType.RW_OR_RO
_last_imported_weights_bytes: int = 0
def set_gms_enabled(enabled: bool) -> None:
global _gms_enabled
_gms_enabled = enabled
def set_gms_lock_mode(mode: RequestedLockType) -> None:
global _gms_lock_mode
_gms_lock_mode = mode
def get_gms_lock_mode() -> RequestedLockType:
return _gms_lock_mode
def get_imported_weights_bytes() -> int:
"""Return total bytes of weights imported/published by the last model load."""
return _last_imported_weights_bytes
def patch_model_loader() -> None:
"""Patch TensorRT-LLM's ModelLoader to load weights through GMS.
Idempotent — safe to call multiple times.
"""
global _model_loader_patched
if _model_loader_patched:
return
import tensorrt_llm._torch.pyexecutor.model_loader as _trt_loader
_original_load = _trt_loader.ModelLoader.load
_original_get_rank_model_storage = _trt_loader.get_rank_model_storage
def patched_get_rank_model_storage(model) -> int:
imported = get_imported_weights_bytes()
if imported > 0:
return imported
return int(_original_get_rank_model_storage(model))
def patched_load(self, checkpoint_dir: str, checkpoint_loader):
if not _gms_enabled:
return _original_load(self, checkpoint_dir, checkpoint_loader)
return _gms_load(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
original_load=_original_load,
)
_trt_loader.get_rank_model_storage = patched_get_rank_model_storage
_trt_loader.ModelLoader.load = patched_load
_model_loader_patched = True
logger.info("[GMS] Patched TensorRT-LLM ModelLoader.load")
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _gms_load(self, checkpoint_dir: str, checkpoint_loader, original_load):
"""Route to RW (write) or RO (read) load path based on granted lock type."""
# Neutralize TRT-LLM's model_weights_memory_tag to prevent its VMM scope
# from overriding the GMS allocator. When sleep_config is set, TRT-LLM
# wraps allocation in virtual_memory_scope(model_weights) — a nested scope
# that would steal allocations away from GMS's gms_use_mem_pool.
saved_tag = getattr(self, "model_weights_memory_tag", None)
self.model_weights_memory_tag = None
device_index = torch.cuda.current_device()
gms_client = get_or_create_gms_client_memory_manager(
get_socket_path(device_index, "weights"),
device_index,
mode=_gms_lock_mode,
tag="weights",
)
try:
return _gms_load_inner(
self,
gms_client,
device_index,
checkpoint_dir,
checkpoint_loader,
original_load,
)
finally:
self.model_weights_memory_tag = saved_tag
def _gms_load_inner(
self, gms_client, device_index, checkpoint_dir, checkpoint_loader, original_load
):
if gms_client.granted_lock_type == GrantedLockType.RO:
return _load_ro(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
gms_client=gms_client,
device_index=device_index,
)
return _load_rw(
self=self,
checkpoint_dir=checkpoint_dir,
checkpoint_loader=checkpoint_loader,
gms_client=gms_client,
device_index=device_index,
original_load=original_load,
)
def _load_rw(
self, checkpoint_dir, checkpoint_loader, gms_client, device_index, original_load
):
"""RW path: load weights from disk into the GMS memory pool, then commit."""
global _last_imported_weights_bytes
target_device = torch.device("cuda", device_index)
with gms_use_mem_pool("weights", target_device):
model, moe_load_balancer = original_load(
self, checkpoint_dir, checkpoint_loader
)
_move_untracked_params(model, gms_client, target_device)
torch.cuda.empty_cache()
_last_imported_weights_bytes = finalize_gms_write(gms_client, model)
logger.info(
"[GMS] TRT-LLM RW: published %.2f GiB",
_last_imported_weights_bytes / (1 << 30),
)
return model, moe_load_balancer
def _load_ro(self, checkpoint_dir, checkpoint_loader, gms_client, device_index):
"""RO path: skip disk I/O, materialize tensors from the committed GMS layout."""
global _last_imported_weights_bytes
from tensorrt_llm._torch.models import AutoModelForCausalLM
from tensorrt_llm._torch.models.modeling_utils import MetaInitMode, timing
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer,
maybe_create_moe_load_balancer,
)
config = self._load_and_validate_config(checkpoint_dir, checkpoint_loader)
with (
timing("Model init total"),
maybe_create_moe_load_balancer(config, self.mapping) as moe_load_balancer,
):
try:
with MetaInitMode():
model = AutoModelForCausalLM.from_config(copy.deepcopy(config))
except Exception as exc:
raise RuntimeError(
"GMS RO path requires successful MetaInitMode model construction"
) from exc
# Some models register cross-layer references like next_attn here.
if hasattr(model, "post_load_weights"):
model.post_load_weights()
materialize_module_from_gms(gms_client, model, device_index=device_index)
_last_imported_weights_bytes = int(gms_client.total_bytes)
logger.info(
"[GMS] TRT-LLM RO: imported %.2f GiB",
_last_imported_weights_bytes / (1 << 30),
)
for module in model.modules():
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
if isinstance(moe_load_balancer, MoeLoadBalancer):
moe_load_balancer.register_weight_slots_after_to_cuda()
moe_load_balancer.finalize_model()
torch.cuda.current_stream().synchronize()
return model, moe_load_balancer
def _ptr_in_gms(gms_client: "GMSClientMemoryManager", ptr: int) -> bool:
"""Return True if the given CUDA VA is within any active GMS mapping."""
for va, mapping in gms_client.mappings.items():
if va <= ptr < va + mapping.aligned_size:
return True
return False
def _move_untracked_params(
model: torch.nn.Module,
gms_client: "GMSClientMemoryManager",
target_device: torch.device,
) -> None:
"""Move CUDA parameters that were allocated outside the GMS pool into it.
TRT-LLM may allocate some parameters outside the pluggable-allocator scope.
This ensures all weight tensors end up tracked by GMS before we commit.
"""
from gpu_memory_service.client.torch.module import _iter_module_tensors
from gpu_memory_service.client.torch.tensor import _tensor_from_pointer
device_index = (
torch.cuda.current_device()
if target_device.index is None
else int(target_device.index)
)
seen: set[int] = set()
with torch.no_grad():
for _name, tensor, tensor_type in _iter_module_tensors(model):
if tensor_type != "parameter" or tensor is None or not tensor.is_cuda:
continue
storage_ptr = tensor.storage().data_ptr()
if storage_ptr in seen:
continue
seen.add(storage_ptr)
if _ptr_in_gms(gms_client, int(tensor.data_ptr())):
continue
# Allocate a new mapping and copy the tensor into it
nbytes = _storage_nbytes(tensor)
base_va = gms_client.create_mapping(size=nbytes, tag="weights")
replacement = _tensor_from_pointer(
int(base_va),
list(tensor.shape),
list(tensor.stride()),
tensor.dtype,
device_index,
)
replacement.copy_(tensor)
tensor.data = replacement
def _storage_nbytes(tensor: torch.Tensor) -> int:
if tensor.numel() == 0:
return 0
element_size = int(tensor.element_size())
shape = list(tensor.shape)
stride = list(tensor.stride())
if not shape:
return element_size
max_offset = sum(
abs(int(s)) * (int(d) - 1)
for s, d in zip(stride, shape, strict=True)
if int(d) > 1
)
return int((max_offset + 1) * element_size)
...@@ -77,6 +77,7 @@ setup( ...@@ -77,6 +77,7 @@ setup(
"gpu_memory_service.integrations", "gpu_memory_service.integrations",
"gpu_memory_service.integrations.common", "gpu_memory_service.integrations.common",
"gpu_memory_service.integrations.sglang", "gpu_memory_service.integrations.sglang",
"gpu_memory_service.integrations.trtllm",
"gpu_memory_service.integrations.vllm", "gpu_memory_service.integrations.vllm",
], ],
package_dir={ package_dir={
...@@ -93,6 +94,7 @@ setup( ...@@ -93,6 +94,7 @@ setup(
"gpu_memory_service.integrations": "integrations", "gpu_memory_service.integrations": "integrations",
"gpu_memory_service.integrations.common": "integrations/common", "gpu_memory_service.integrations.common": "integrations/common",
"gpu_memory_service.integrations.sglang": "integrations/sglang", "gpu_memory_service.integrations.sglang": "integrations/sglang",
"gpu_memory_service.integrations.trtllm": "integrations/trtllm",
"gpu_memory_service.integrations.vllm": "integrations/vllm", "gpu_memory_service.integrations.vllm": "integrations/vllm",
}, },
package_data={ package_data={
......
...@@ -9,6 +9,7 @@ import json ...@@ -9,6 +9,7 @@ import json
import logging import logging
import os import os
import sys import sys
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import ExitStack from contextlib import ExitStack
...@@ -34,6 +35,23 @@ def get_gpu_memory_used(device: int = 0) -> int: ...@@ -34,6 +35,23 @@ def get_gpu_memory_used(device: int = 0) -> int:
pynvml.nvmlShutdown() pynvml.nvmlShutdown()
def wait_for_memory_drop(
baseline_bytes: int,
*,
timeout_s: float = 30.0,
poll_interval_s: float = 0.5,
) -> int:
"""Poll until GPU memory drops below *baseline_bytes*, then return current usage."""
deadline = time.monotonic() + timeout_s
current = get_gpu_memory_used()
while time.monotonic() < deadline:
if current < baseline_bytes:
return current
time.sleep(poll_interval_s)
current = get_gpu_memory_used()
return current
class GMSProcessManager: class GMSProcessManager:
"""Start the shared GMS daemons and frontend for one test scenario.""" """Start the shared GMS daemons and frontend for one test scenario."""
...@@ -43,10 +61,12 @@ class GMSProcessManager: ...@@ -43,10 +61,12 @@ class GMSProcessManager:
engine_cls, engine_cls,
*, *,
read_only_weights: bool = False, read_only_weights: bool = False,
tags: tuple[str, ...] = ("weights", "kv_cache"),
): ):
self._request = request self._request = request
self._engine_cls = engine_cls self._engine_cls = engine_cls
self._read_only_weights = read_only_weights self._read_only_weights = read_only_weights
self._tags = tags
self._stack: ExitStack | None = None self._stack: ExitStack | None = None
self.frontend_port: int | None = None self.frontend_port: int | None = None
self.weights_gms = None self.weights_gms = None
...@@ -57,8 +77,14 @@ class GMSProcessManager: ...@@ -57,8 +77,14 @@ class GMSProcessManager:
def __enter__(self): def __enter__(self):
stack = ExitStack() stack = ExitStack()
try: try:
self.weights_gms = stack.enter_context(GMSServer(device=0, tag="weights")) if "weights" in self._tags:
self.kv_cache_gms = stack.enter_context(GMSServer(device=0, tag="kv_cache")) self.weights_gms = stack.enter_context(
GMSServer(device=0, tag="weights")
)
if "kv_cache" in self._tags:
self.kv_cache_gms = stack.enter_context(
GMSServer(device=0, tag="kv_cache")
)
frontend = stack.enter_context( frontend = stack.enter_context(
DynamoFrontendProcess( DynamoFrontendProcess(
self._request, self._request,
...@@ -306,6 +332,96 @@ class VLLMWithGMSProcess(GMSEngineProcess): ...@@ -306,6 +332,96 @@ class VLLMWithGMSProcess(GMSEngineProcess):
return {"level": 2} return {"level": 2}
class TRTLLMWithGMSProcess(GMSEngineProcess):
"""TensorRT-LLM engine with GMS weights + sleep/wake enabled."""
quiesce_route = "release_memory_occupation"
resume_route = "resume_memory_occupation"
# Override via environment variables for CI or custom setups.
TRTLLM_GMS_MODEL_NAME = os.environ.get(
"TRTLLM_GMS_MODEL_NAME", FAULT_TOLERANCE_MODEL_NAME
)
TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION = os.environ.get(
"TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION", "0.9"
)
TRTLLM_GMS_MAX_SEQ_LEN = os.environ.get("TRTLLM_GMS_MAX_SEQ_LEN", "256")
TRTLLM_GMS_MAX_NUM_TOKENS = os.environ.get("TRTLLM_GMS_MAX_NUM_TOKENS", "256")
TRTLLM_GMS_OVERRIDE_ENGINE_ARGS = os.environ.get(
"TRTLLM_GMS_OVERRIDE_ENGINE_ARGS", ""
)
def __init__(
self,
request,
frontend_port: int,
*,
engine_id: str,
read_only_weights: bool = False,
override_engine_args: str | None = None,
):
reserved_ports = allocate_ports(1)
self._override_engine_args = override_engine_args
try:
super().__init__(
request,
engine_id,
reserved_ports[0],
frontend_port,
reserved_ports,
read_only_weights=read_only_weights,
)
except Exception:
deallocate_ports(reserved_ports)
raise
def env_updates(self) -> dict[str, str]:
env = {
"CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0"),
"TLLM_WORKER_USE_SINGLE_PROCESS": "1",
"MPI4PY_MPIABI": "openmpi",
"OMPI_MCA_coll_ucc_enable": "0",
}
venv = os.environ.get("VIRTUAL_ENV")
if venv:
venv_lib = os.path.join(venv, "lib")
existing = os.environ.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = f"{venv_lib}:{existing}" if existing else venv_lib
return env
def command(self) -> list[str]:
command = [
sys.executable,
"-m",
"dynamo.trtllm",
"--model",
self.TRTLLM_GMS_MODEL_NAME,
"--gpus-per-node",
"1",
"--load-format",
"gms",
"--free-gpu-memory-fraction",
self.TRTLLM_GMS_FREE_GPU_MEMORY_FRACTION,
"--max-seq-len",
self.TRTLLM_GMS_MAX_SEQ_LEN,
"--max-num-tokens",
self.TRTLLM_GMS_MAX_NUM_TOKENS,
]
effective_override = self._override_engine_args
if effective_override is None:
effective_override = self.TRTLLM_GMS_OVERRIDE_ENGINE_ARGS
if effective_override:
command.extend(["--override-engine-args", effective_override])
extra_config = self.model_loader_extra_config()
if extra_config is not None:
command.extend(["--model-loader-extra-config", extra_config])
return command
def quiesce_payload(self) -> dict:
return {}
class SGLangWithGMSProcess(GMSEngineProcess): class SGLangWithGMSProcess(GMSEngineProcess):
quiesce_route = "release_memory_occupation" quiesce_route = "release_memory_occupation"
resume_route = "resume_memory_occupation" resume_route = "resume_memory_occupation"
......
...@@ -226,3 +226,31 @@ def assert_kv_history( ...@@ -226,3 +226,31 @@ def assert_kv_history(
] ]
assert len(clear_counts) >= cleared_layouts assert len(clear_counts) >= cleared_layouts
assert all(count > 0 for count in clear_counts[:cleared_layouts]) assert all(count > 0 for count in clear_counts[:cleared_layouts])
def wait_for_weights_state(
weights_gms,
expected_state,
*,
min_ro_sessions: int = 0,
expected_hash: str | None = None,
timeout: float = 30.0,
):
"""Poll until the weights GMS daemon reaches *expected_state*."""
deadline = time.monotonic() + timeout
while True:
ws = weights_gms.get_runtime_state()
if (
ws.state == expected_state
and ws.allocation_count > 0
and ws.memory_layout_hash
and ws.ro_session_count >= min_ro_sessions
and (expected_hash is None or ws.memory_layout_hash == expected_hash)
):
return ws
if time.monotonic() > deadline:
raise TimeoutError(
f"Weights: state={ws.state} (want {expected_state}), "
f"allocs={ws.allocation_count}, hash={ws.memory_layout_hash}"
)
time.sleep(0.1)
...@@ -6,10 +6,12 @@ from __future__ import annotations ...@@ -6,10 +6,12 @@ from __future__ import annotations
import logging import logging
import pytest import pytest
from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import ( from tests.gpu_memory_service.common.runtime import (
GMSProcessManager, GMSProcessManager,
SGLangWithGMSProcess, SGLangWithGMSProcess,
TRTLLMWithGMSProcess,
VLLMWithGMSProcess, VLLMWithGMSProcess,
get_gpu_memory_used, get_gpu_memory_used,
) )
...@@ -20,6 +22,7 @@ from tests.gpu_memory_service.flow_assertions import ( ...@@ -20,6 +22,7 @@ from tests.gpu_memory_service.flow_assertions import (
assert_weights_published_once, assert_weights_published_once,
quiesce_engine, quiesce_engine,
wait_for_resumed_layout, wait_for_resumed_layout,
wait_for_weights_state,
) )
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
...@@ -137,3 +140,110 @@ def test_gms_basic_quiesce_resume_sglang( ...@@ -137,3 +140,110 @@ def test_gms_basic_quiesce_resume_sglang(
predownload_models, predownload_models,
): ):
_run_quiesce_resume_test(request, SGLangWithGMSProcess) _run_quiesce_resume_test(request, SGLangWithGMSProcess)
# ---------------------------------------------------------------------------
# TRT-LLM standalone tests (weights-only GMS topology, no KV cache GMS)
# ---------------------------------------------------------------------------
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_basic_quiesce_resume_trtllm(
request,
runtime_services_dynamic_ports,
predownload_models,
):
"""Weights-only quiesce/resume for TRT-LLM (no KV cache GMS)."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
engine = manager.start_engine("engine")
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Initial inference failed",
success_message="Initial inference OK",
)
ws = wait_for_weights_state(weights_gms, ServerState.RO, timeout=60.0)
weights_hash = ws.memory_layout_hash
mem_before = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
mem_after = get_gpu_memory_used()
released = mem_before - mem_after
logger.info(
"TRT-LLM quiesce: %.2f -> %.2f GiB (freed %.0f MB)",
mem_before / (1 << 30),
mem_after / (1 << 30),
released / (1 << 20),
)
assert released > 0
wait_for_weights_state(
weights_gms, ServerState.COMMITTED, expected_hash=weights_hash
)
assert_weights_published_once(weights_gms.get_event_history().events)
assert engine.resume()["status"] == "ok"
mem_resumed = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"TRT-LLM resume",
mem_after,
mem_resumed,
released,
min_fraction=0.6,
)
wait_for_weights_state(weights_gms, ServerState.RO, expected_hash=weights_hash)
assert_weights_published_once(weights_gms.get_event_history().events)
assert_completion_ok(
frontend_port,
"Goodbye",
failure_message="Post-resume inference failed",
success_message="Post-resume inference OK",
)
logger.info("Memory freed: %.0f MB", released / (1 << 20))
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_read_only_import_trtllm(
request,
runtime_services_dynamic_ports,
predownload_models,
):
"""A second TRT-LLM process with read_only_weights=True imports weights
from the committed layout published by the first, sharing GPU memory."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
manager.start_engine("rw-engine")
ws = wait_for_weights_state(weights_gms, ServerState.RO, timeout=60.0)
weights_hash = ws.memory_layout_hash
manager.start_engine("ro-engine", read_only_weights=True)
wait_for_weights_state(
weights_gms,
ServerState.RO,
min_ro_sessions=1,
expected_hash=weights_hash,
timeout=60.0,
)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="RW+RO inference failed",
success_message="RW+RO inference OK",
)
...@@ -15,6 +15,7 @@ from gpu_memory_service.server.fsm import ServerState ...@@ -15,6 +15,7 @@ from gpu_memory_service.server.fsm import ServerState
from tests.gpu_memory_service.common.runtime import ( from tests.gpu_memory_service.common.runtime import (
GMSProcessManager, GMSProcessManager,
SGLangWithGMSProcess, SGLangWithGMSProcess,
TRTLLMWithGMSProcess,
VLLMWithGMSProcess, VLLMWithGMSProcess,
get_gpu_memory_used, get_gpu_memory_used,
) )
...@@ -26,6 +27,7 @@ from tests.gpu_memory_service.flow_assertions import ( ...@@ -26,6 +27,7 @@ from tests.gpu_memory_service.flow_assertions import (
quiesce_engine, quiesce_engine,
wait_for_active_layout, wait_for_active_layout,
wait_for_resumed_layout, wait_for_resumed_layout,
wait_for_weights_state,
) )
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -321,3 +323,133 @@ def test_gms_shadow_engine_failover_sglang( ...@@ -321,3 +323,133 @@ def test_gms_shadow_engine_failover_sglang(
request, runtime_services_dynamic_ports, predownload_models request, runtime_services_dynamic_ports, predownload_models
): ):
_run_shadow_failover_test(request, SGLangWithGMSProcess) _run_shadow_failover_test(request, SGLangWithGMSProcess)
# ---------------------------------------------------------------------------
# TRT-LLM standalone failover test (weights-only GMS, no KV cache GMS)
# ---------------------------------------------------------------------------
def _trtllm_quiesce(
weights_gms,
engine,
*,
label: str,
expected_hash: str | None = None,
):
"""Quiesce a weights-only TRT-LLM engine and return state tuple."""
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=expected_hash,
timeout=60.0,
)
mem_before = get_gpu_memory_used()
assert engine.quiesce()["status"] == "ok"
mem_after = get_gpu_memory_used()
released = mem_before - mem_after
logger.info(
"%s: %.2f -> %.2f GiB (freed %.0f MB)",
label,
mem_before / (1 << 30),
mem_after / (1 << 30),
released / (1 << 20),
)
assert released > 0
ws = wait_for_weights_state(weights_gms, ServerState.COMMITTED)
return ws, released, mem_after
@pytest.mark.trtllm
@pytest.mark.e2e
@pytest.mark.gpu_1
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.timeout(600)
def test_gms_shadow_engine_failover_trtllm(
request, runtime_services_dynamic_ports, predownload_models
):
"""Weights-only shadow failover for TRT-LLM (no KV cache GMS)."""
with GMSProcessManager(request, TRTLLMWithGMSProcess, tags=("weights",)) as manager:
frontend_port = manager.frontend_port
weights_gms = manager.weights_gms
# Shadow A publishes weights, then quiesces.
shadow_a = manager.start_engine("shadow-a")
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow A inference failed",
success_message="Shadow A inference OK",
)
ws_a, released_a, _ = _trtllm_quiesce(
weights_gms, shadow_a, label="Shadow A quiesce"
)
weights_hash = ws_a.memory_layout_hash
# Shadow B starts RO, then quiesces.
shadow_b = manager.start_engine("shadow-b", read_only_weights=True)
assert_completion_ok(
frontend_port,
"Hello",
failure_message="Shadow B inference failed",
success_message="Shadow B inference OK",
)
_, _, mem_after_b = _trtllm_quiesce(
weights_gms,
shadow_b,
label="Shadow B quiesce",
expected_hash=weights_hash,
)
assert_weights_published_once(weights_gms.get_event_history().events)
# Primary starts RO.
primary = manager.start_engine("primary", read_only_weights=True)
assert_completion_ok(
frontend_port,
"Primary test",
failure_message="Primary inference failed",
success_message="Primary inference OK",
)
primary_mem = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Primary active",
mem_after_b,
primary_mem,
released_a,
min_fraction=0.6,
)
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=weights_hash,
min_ro_sessions=1,
)
# Kill primary, resume shadow A immediately (no KV blocking).
_kill_process_group(primary)
resume_result = shadow_a.resume(timeout=180)
assert resume_result["status"] == "ok"
shadow_mem = get_gpu_memory_used()
assert_memory_restored_after_quiesce(
"Shadow A resume",
mem_after_b,
shadow_mem,
released_a,
min_fraction=0.6,
)
wait_for_weights_state(
weights_gms,
ServerState.RO,
expected_hash=weights_hash,
min_ro_sessions=1,
)
assert_weights_published_once(weights_gms.get_event_history().events)
assert_completion_ok(
frontend_port,
"Post failover",
failure_message="Shadow after failover failed",
success_message="Shadow after failover OK",
retry_timeout=30.0,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment