Unverified Commit 0028cdf4 authored by Kyle McGill's avatar Kyle McGill Committed by GitHub
Browse files

feat: Use ForwardPassCallback api from TRTLLM to register end of forward pass...


feat: Use ForwardPassCallback api from TRTLLM to register end of forward pass callback to enable cuda graphs (#3297)
Signed-off-by: default avatarKyle McGill <kmcgill@nvidia.com>
parent cfc1e6ce
......@@ -39,6 +39,8 @@ pub trait Worker: Send + Sync {
fn start_load_kv(&mut self) -> anyhow::Result<()>;
fn execute_offload_operations(&mut self) -> anyhow::Result<()>;
fn save_kv_layer(&mut self, layer_idx: usize) -> anyhow::Result<()>;
fn get_finished(
......@@ -215,16 +217,24 @@ impl Worker for KvConnectorWorker {
Ok(())
}
// Assumes the operations are in a valid state for offloading.
fn execute_offload_operations(&mut self) -> anyhow::Result<()> {
let offloading_operations = std::mem::take(&mut self.offloading_operations);
for operation in offloading_operations {
self.connector.enqueue_request(operation);
}
Ok(())
}
fn save_kv_layer(&mut self, _layer_idx: usize) -> anyhow::Result<()> {
self.layers_complete += 1;
if self.layers_complete == self.layer_events.len() {
let offloading_operations = std::mem::take(&mut self.offloading_operations);
// block on the the completion of the last layer
// todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
// or put the event on a stream and use stream waits to keep it all on device.
event_sync_blocking(self.layer_events[self.layers_complete - 1]);
for operation in offloading_operations {
self.connector.enqueue_request(operation);
if let Err(e) = self.execute_offload_operations() {
tracing::error!("Failed to execute offload operations: {}", e);
}
}
Ok(())
......@@ -431,6 +441,12 @@ impl PyTrtllmKvConnectorWorker {
.map_err(to_pyerr)
}
pub fn execute_offload_operations(&mut self) -> PyResult<()> {
self.connector_worker
.execute_offload_operations()
.map_err(to_pyerr)
}
pub fn save_kv_layer(&mut self, layer_idx: usize) -> PyResult<()> {
self.connector_worker
.save_kv_layer(layer_idx)
......
......@@ -13,6 +13,21 @@ from dynamo.runtime import DistributedRuntime
class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
def _callable_object(self) -> callable:
assert (
self._connector is not None
), "Expected cache connector worker to have non-None _connector obj"
assert (
self.event is not None
), "Expected cache connector worker to have non-None event obj"
def callback():
self.event.record()
self.event.synchronize()
self._connector.execute_offload_operations()
return callback
def __init__(self, llm_args: TorchLlmArgs):
super().__init__(llm_args)
......@@ -22,6 +37,18 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
self.rank = mappings.rank
self._connector = RustKvConnectorWorker(self.drt, str(self.rank))
self.event = torch.cuda.Event()
# Default to old way of processing offload
self.use_forward_pass_callable = False
def register_forward_pass_callable(self) -> callable:
"""
Register a callable object which will be called at the
end of the forward pass.
"""
self.use_forward_pass_callable = True
return self._callable_object()
def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
"""
......@@ -30,7 +57,6 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
Args:
kv_cache_tensor: The contiguous KV cache tensor.
"""
print(f"Register KV Caches on rank {self.rank}")
logger.info(
f"KvConnectorWorker started registering the kv caches on rank {self.rank}"
)
......@@ -104,8 +130,9 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
layer_idx: The index of the layer to save.
stream: The stream the forward pass is being executed on.
"""
self.events[layer_idx].record(stream)
self._connector.save_kv_layer(layer_idx)
if not self.use_forward_pass_callable:
self.events[layer_idx].record(stream)
self._connector.save_kv_layer(layer_idx)
def get_finished(
self, finished_gen_req_ids: list[int], started_loading_req_ids: list[int]
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
backend: pytorch
cuda_graph_config:
max_batch_size: 8
kv_cache_config:
enable_partial_reuse: false
free_gpu_memory_fraction: 0.80
max_tokens: 8192
kv_connector_config:
connector_module: dynamo.llm.trtllm_integration.connector
connector_scheduler_class: DynamoKVBMConnectorLeader
connector_worker_class: DynamoKVBMConnectorWorker
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
backend: pytorch
cuda_graph_config: null
kv_cache_config:
enable_partial_reuse: false
free_gpu_memory_fraction: 0.80
max_tokens: 8192
kv_connector_config:
connector_module: dynamo.llm.trtllm_integration.connector
connector_scheduler_class: DynamoKVBMConnectorLeader
connector_worker_class: DynamoKVBMConnectorWorker
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Determinism test for language model API using pytest.
This test suite checks if the model produces deterministic outputs
when given the same inputs with fixed seed and temperature=0.
The test uses comprehensive server warmup (sending all test prompts
before validation) to avoid server initialization effects that could
impact determinism measurements.
"""
import logging
import os
import shutil
import pytest
import requests
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from tests.utils.payloads import check_models_api
logger = logging.getLogger(__name__)
# Just need a model to show the config works rather than any stress of the system.
MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
SERVED_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PROMPT = "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with TRTLLM backend"""
def __init__(self, request, worker_id: str, engine_config: str):
self.worker_id = worker_id
command = [
"python3",
"-m",
"dynamo.trtllm",
"--model",
MODEL_PATH,
"--served-model-name",
SERVED_MODEL_NAME,
"--extra-engine-args",
engine_config,
]
# Set debug logging environment
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = "9345"
env["DYN_KVBM_CPU_CACHE_GB"] = "20"
env["DYN_KVBM_DISK_CACHE_GB"] = "60"
env["DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS"] = "1200"
# TODO: Have the managed process take a command name explicitly to distinguish
# between processes started with the same command.
log_dir = f"{request.node.name}_{worker_id}"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
env=env,
health_check_urls=[
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
("http://localhost:9345/health", self.is_ready),
],
timeout=300,
display_output=True,
terminate_existing=False,
log_dir=log_dir,
)
def get_pid(self) -> int | None:
"""Get the PID of the worker process"""
return self.proc.pid if hasattr(self, "proc") and self.proc else None
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
logger.info(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is ready"
)
return True
logger.warning(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is not ready: {data.get('status')}"
)
except ValueError:
logger.warning(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} health response is not valid JSON"
)
return False
def send_completion_request(
prompt: str, max_tokens: int, timeout: int = 120
) -> requests.Response:
"""Send a completion request to the frontend"""
payload = {
"model": SERVED_MODEL_NAME,
"prompt": prompt,
"stream": False,
"max_tokens": max_tokens,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
try:
response = requests.post(
"http://localhost:8000/v1/completions",
headers=headers,
json=payload,
timeout=timeout,
)
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
@pytest.mark.kvbm
@pytest.mark.trtllm_marker
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.gpu_1
@pytest.mark.skip(
reason="Enable these tests once `main` dynamo upgrades to TRTLLM 1.2+"
)
def test_kvbm_without_cuda_graph_enabled(request, runtime_services):
"""
End-to-end test for TRTLLM worker with cuda_graph_config not defined and
KVBM enabled.
This test verifies a TRTLLM worker is able to serve requests when
cuda graphs are not enabled in pytorch. KVBM should be able to offload
blocks regardless.
"""
logger.info("Starting frontend...")
with DynamoFrontendProcess(request):
logger.info("Frontend started.")
engine_config_with_cuda_graph_and_kvbm = (
"tests/kvbm/engine_config_without_cuda_graph_and_kvbm.yaml"
)
logger.info("Starting worker...")
with DynamoWorkerProcess(
request, "decode", engine_config_with_cuda_graph_and_kvbm
) as worker:
logger.info(f"Worker PID: {worker.get_pid()}")
response = send_completion_request(PROMPT, 100, timeout=10)
assert (
response.ok
), f"Expected successful status, got {response.status_code}"
logger.info(f"Completion request succeeded: {response.status_code}")
@pytest.mark.kvbm
@pytest.mark.trtllm_marker
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.gpu_1
@pytest.mark.skip(
reason="Enable these tests once dynamo `main` upgrades to TRTLLM 1.2+"
)
def test_kvbm_with_cuda_graph_enabled(request, runtime_services):
"""
End-to-end test for TRTLLM worker with cuda_graph_config defined and
KVBM enabled.
This test verifies a TRTLLM worker is able to serve requests when
cuda graphs are enabled in pytorch. KVBM should be able to offload
blocks regardless.
"""
logger.info("Starting frontend...")
with DynamoFrontendProcess(request):
logger.info("Frontend started.")
engine_config_with_cuda_graph_and_kvbm = (
"tests/kvbm/engine_config_with_cuda_graph_and_kvbm.yaml"
)
logger.info("Starting worker...")
with DynamoWorkerProcess(
request, "decode", engine_config_with_cuda_graph_and_kvbm
) as worker:
logger.info(f"Worker PID: {worker.get_pid()}")
response = send_completion_request(PROMPT, 100, timeout=10)
assert (
response.ok
), f"Expected successful status, got {response.status_code}"
logger.info(f"Completion request succeeded: {response.status_code}")
......@@ -568,6 +568,36 @@ class ManagedProcess:
return []
class DynamoFrontendProcess(ManagedProcess):
"""Process manager for Dynamo frontend"""
_logger = logging.getLogger()
def __init__(self, request):
command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]
log_dir = f"{request.node.name}_frontend"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
self._logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
display_output=True,
terminate_existing=True,
log_dir=log_dir,
)
def get_pid(self) -> int | None:
"""Get the PID of the worker process"""
return self.proc.pid if self.proc else None
def main():
with ManagedProcess(
command=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment