Unverified Commit ac53c0bb authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

feat: Enable disagg support in trtllm standalone script (#1355)

parent 43991e76
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is a sample config for TensorRT-LLM engine.
# The config provides smaller free_gpu_memory_fraction to ensure that the engine
# does not use all the GPU memory and both prefill and decode workers can fit in
# the GPU memory when running in disaggregated mode.
# You might have to tweak this config based on your model size and GPU memory.
backend: pytorch
kv_cache_config:
free_gpu_memory_fraction: 0.40
...@@ -7,18 +7,28 @@ ...@@ -7,18 +7,28 @@
# #
# `dynamo-run out=trtllm` runs this script # `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params # Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
#
# Disaggregated serving:
# - Ingress: dynamo run in=http out=dyn
# - Decode Worker: python3 trtllm_inc.py --task=decode --extra-engine-args=trtllm_config/sample.yaml
# - Prefill Worker: python3 trtllm_inc.py --task=prefill --extra-engine-args=trtllm_config/sample.yaml
import argparse import argparse
import asyncio import asyncio
import base64
import copy
import logging import logging
import sys import sys
import warnings import warnings
from dataclasses import asdict, dataclass
from typing import Optional from typing import Optional
import uvloop import uvloop
# Import TRTLLM and related modules # Import TRTLLM and related modules
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
...@@ -34,6 +44,8 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -34,6 +44,8 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet. # Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default endpoint for the remote prefill service.
DEFAULT_PREFILL_ENDPOINT = "dyn://dynamo.prefill.generate"
# Default buffer size for kv cache events. # Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
...@@ -41,6 +53,64 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 ...@@ -41,6 +53,64 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
return tuple(endpoint_parts)
class DisaggregatedParamsCodec:
"""
Codec for encoding and decoding disaggregated params for network transfer.
"""
@staticmethod
def decode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
opaque_state = (
base64.b64decode(disaggregated_params.opaque_state)
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod
def encode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
encoded_opaque_state = (
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8")
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
class Config: class Config:
"""Command line parameters or defaults""" """Command line parameters or defaults"""
...@@ -53,6 +123,37 @@ class Config: ...@@ -53,6 +123,37 @@ class Config:
kv_block_size: int kv_block_size: int
extra_engine_args: str extra_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disaggregation_mode: str
remote_prefill_endpoint: str
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint})"
)
@dataclass
class RequestHandlerConfig:
"""
Configuration for the request handler
"""
component: object
engine: object
default_sampling_params: object
publisher: object
disaggregation_mode: str
remote_prefill_client: object
class RequestHandler: class RequestHandler:
...@@ -60,23 +161,113 @@ class RequestHandler: ...@@ -60,23 +161,113 @@ class RequestHandler:
Request handler for the generate endpoint Request handler for the generate endpoint
""" """
def __init__(self, component, engine, default_sampling_params, publishers): def __init__(self, config: RequestHandlerConfig):
self.engine = engine self.engine = config.engine
self.component = component self.component = config.component
self.default_sampling_params = default_sampling_params self.default_sampling_params = config.default_sampling_params
self.publishers = publishers self.publisher = config.publisher
self.disaggregation_mode = config.disaggregation_mode
self.remote_prefill_client = config.remote_prefill_client
self.first_generation = True self.first_generation = True
async def remote_prefill(self, request):
"""
Send a prefill request to the remote prefill worker.
Args:
request: The original request to be sent for prefill
Returns:
The response from the remote prefill worker
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = copy.deepcopy(request)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request["stop_conditions"]["max_tokens"] = 1
# Set the disaggregated params to context_only for remote prefill
prefill_request["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(
DisaggregatedParams(request_type="context_only")
)
)
if self.remote_prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self.remote_prefill_client.round_robin(
prefill_request
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request): async def generate(self, request):
# Check if there is an error in the publishers error queue # Check if there is an error in the publisher error queue
publishers_error = ( publishers_error = (
self.publishers.check_error_queue() if self.publishers else None self.publisher.check_error_queue() if self.publisher else None
) )
if publishers_error: if publishers_error:
raise publishers_error raise publishers_error
inputs = request["token_ids"] inputs = request["token_ids"]
# Decode the disaggregated params from the request
if "disaggregated_params" in request:
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
)
else:
disaggregated_params = None
num_output_tokens_so_far = 0
if self.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**remote_prefill_response["disaggregated_params"])
)
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items(): for key, value in request["sampling_options"].items():
if not value: if not value:
...@@ -88,18 +279,20 @@ class RequestHandler: ...@@ -88,18 +279,20 @@ class RequestHandler:
if max_tokens: if max_tokens:
sampling_params.max_tokens = max_tokens sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
# TODO: Disable streaming for context only requests when adding disagg support # TODO: Disable streaming for context only requests when adding disagg support
async for res in self.engine.llm.generate_async( async for res in self.engine.llm.generate_async(
inputs=inputs, sampling_params=sampling_params, streaming=True inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self.disaggregation_mode != "prefill"),
): ):
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
# can be retrieved. # can be retrieved.
if self.first_generation and self.publishers: if self.first_generation and self.publisher:
self.publishers.start() self.publisher.start()
self.first_generation = False self.first_generation = False
if res.finished: if res.finished and self.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []} yield {"finish_reason": "stop", "token_ids": []}
break break
...@@ -114,6 +307,11 @@ class RequestHandler: ...@@ -114,6 +307,11 @@ class RequestHandler:
out["finish_reason"] = output.finish_reason out["finish_reason"] = output.finish_reason
if output.stop_reason: if output.stop_reason:
out["stop_reason"] = output.stop_reason out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield out yield out
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
...@@ -127,6 +325,24 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -127,6 +325,24 @@ async def init(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
""" """
logging.info(f"Initializing the worker with config: {config}")
remote_prefill_client = None
if config.disaggregation_mode == "decode":
logging.info(
f"Initializing remote prefill client for endpoint: {config.remote_prefill_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.remote_prefill_endpoint
)
remote_prefill_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
await component.create_service() await component.create_service()
...@@ -177,16 +393,36 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -177,16 +393,36 @@ async def init(runtime: DistributedRuntime, config: Config):
async with get_tensorrtllm_engine(engine_args) as engine: async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, if config.disaggregation_mode != "prefill":
endpoint, # Register the model with the endpoint if disaggregation mode is not prefill.
config.model_path, # Prefill worker will get the request directly from the Decode worker and not
config.model_name, # through the ingress.
kv_cache_block_size=config.kv_block_size, # FIXME: Enable publishing events and metrics for disaggregated prefill.
# Currently prefill workers are chosen in round-robin fashion.
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
publisher=None,
disaggregation_mode=config.disaggregation_mode,
remote_prefill_client=remote_prefill_client,
) )
if config.publish_events_and_metrics: if (
# Initialize and pass in the publishers to the request handler to config.publish_events_and_metrics
and config.disaggregation_mode != "prefill"
):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics. # publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component( kv_listener = runtime.namespace(config.namespace).component(
config.component config.component
...@@ -198,13 +434,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -198,13 +434,11 @@ async def init(runtime: DistributedRuntime, config: Config):
int(endpoint.lease_id()), int(endpoint.lease_id()),
config.kv_block_size, config.kv_block_size,
) as publisher: ) as publisher:
handler = RequestHandler( handler_config.publisher = publisher
component, engine, default_sampling_params, publisher handler = RequestHandler(handler_config)
)
await endpoint.serve_endpoint(handler.generate) await endpoint.serve_endpoint(handler.generate)
else: else:
# No publishers, so just pass in None to the request handler. handler = RequestHandler(handler_config)
handler = RequestHandler(component, engine, default_sampling_params, None)
await endpoint.serve_endpoint(handler.generate) await endpoint.serve_endpoint(handler.generate)
...@@ -253,16 +487,61 @@ def cmd_line_args(): ...@@ -253,16 +487,61 @@ def cmd_line_args():
parser.add_argument( parser.add_argument(
"--publish-events-and-metrics", "--publish-events-and-metrics",
action="store_true", action="store_true",
help="Publish events and metrics to the dynamo components.", help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
)
parser.add_argument(
"--task",
type=str,
action="append",
choices=["prefill", "decode", "prefill_and_decode"],
default=[],
help="Specifies the task for the engine. Can be specified multiple time for different tasks. Will raise an error if conflicting tasks are specified.",
)
parser.add_argument(
"--remote-prefill-endpoint",
type=str,
default=DEFAULT_PREFILL_ENDPOINT,
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send prefill requests to when running in decode disaggregation mode. Default: {DEFAULT_PREFILL_ENDPOINT}",
) )
args = parser.parse_args() args = parser.parse_args()
# Validate arguments
if args.context_length is not None: if args.context_length is not None:
warnings.warn( warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.", "--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning, UserWarning,
) )
endpoint = args.endpoint
# disaggregation mode
disaggregation_mode = None
for choice in ["prefill", "decode", "prefill_and_decode"]:
if choice in args.task:
if disaggregation_mode is not None:
raise ValueError(
f"Conflicting tasks specified: {args.task}. Please specify only one task."
)
disaggregation_mode = choice
if disaggregation_mode is None:
disaggregation_mode = "prefill_and_decode"
if disaggregation_mode == "prefill":
if args.remote_prefill_endpoint != DEFAULT_PREFILL_ENDPOINT:
logging.error(
"--remote-prefill-endpoint is not supported when running in prefill disaggregation mode."
)
sys.exit(1)
else:
endpoint = DEFAULT_PREFILL_ENDPOINT
if args.publish_events_and_metrics:
warnings.warn(
"--publish-events-and-metrics is not supported when running in prefill disaggregation mode.",
UserWarning,
)
config = Config() config = Config()
config.model_path = args.model_path config.model_path = args.model_path
if args.model_name: if args.model_name:
...@@ -271,15 +550,9 @@ def cmd_line_args(): ...@@ -271,15 +550,9 @@ def cmd_line_args():
# This becomes an `Option` on the Rust side # This becomes an `Option` on the Rust side
config.model_name = None config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1) parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint_parts = endpoint_str.split(".") endpoint
if len(endpoint_parts) != 3: )
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace config.namespace = parsed_namespace
config.component = parsed_component_name config.component = parsed_component_name
...@@ -288,6 +561,8 @@ def cmd_line_args(): ...@@ -288,6 +561,8 @@ def cmd_line_args():
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
config.remote_prefill_endpoint = args.remote_prefill_endpoint
return config return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment