trtllm_utils.py 7.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
from typing import Optional

from utils.request_handlers.handler_base import (
    DisaggregationMode,
    DisaggregationStrategy,
)

# Fallback values used by cmd_line_args() when the matching flag is omitted.

# Endpoint served by a worker that is first in the chain.
DEFAULT_ENDPOINT = "dyn://dynamo.tensorrt_llm.generate"
# Model loaded when --model-path is not given.
DEFAULT_MODEL_PATH = "TinyLlama-1.1B-Instruct"
# Endpoint served by a worker that is NOT first in the chain; also the target
# the first worker forwards to when the mode is not AGGREGATED.
DEFAULT_NEXT_ENDPOINT = "dyn://dynamo.tensorrt_llm_next.generate"
# Initial strategy/mode for Config and the corresponding CLI flags.
DEFAULT_DISAGGREGATION_STRATEGY = DisaggregationStrategy.DECODE_FIRST
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED


class Config:
    """Worker settings taken from the command line, or their defaults.

    A fresh instance carries only the built-in defaults below; callers fill
    in the real values afterwards by plain attribute assignment.
    """

    def __init__(self) -> None:
        # The worker's own endpoint, stored as its three separate parts.
        self.namespace: str = ""
        self.component: str = ""
        self.endpoint: str = ""

        # Model to load and the name it is served under (None = not set).
        self.model_path: str = ""
        self.served_model_name: Optional[str] = None

        # Engine sizing and extra engine arguments.
        self.tensor_parallel_size: int = 1
        self.kv_block_size: int = 32
        self.extra_engine_args: str = ""

        self.publish_events_and_metrics: bool = False

        # Disaggregation role of this worker and the ordering of the chain.
        self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
        self.disaggregation_strategy: DisaggregationStrategy = (
            DEFAULT_DISAGGREGATION_STRATEGY
        )
        # Full endpoint string of the following worker; empty when unused.
        self.next_endpoint: str = ""

    def __str__(self) -> str:
        """Return ``Config(name=value, ...)`` listing every setting in order."""
        shown_attrs = (
            "namespace",
            "component",
            "endpoint",
            "model_path",
            "served_model_name",
            "tensor_parallel_size",
            "kv_block_size",
            "extra_engine_args",
            "publish_events_and_metrics",
            "disaggregation_mode",
            "disaggregation_strategy",
            "next_endpoint",
        )
        # f-string formatting (not str()) keeps each value rendered exactly as
        # an inline ``{self.attr}`` placeholder would render it.
        rendered = ", ".join(f"{attr}={getattr(self, attr)}" for attr in shown_attrs)
        return f"Config({rendered})"


def is_first_worker(config) -> bool:
    """
    Tell whether this worker sits at the head of the disaggregation chain.

    A worker is first when it runs in AGGREGATED mode, or when its mode is
    the one named by the strategy: PREFILL under PREFILL_FIRST, DECODE under
    DECODE_FIRST.

    Args:
        config: Object exposing ``disaggregation_mode`` and
            ``disaggregation_strategy`` attributes.

    Returns:
        True if the worker is first in the chain, otherwise False.
    """
    # An aggregated worker has nothing in front of it.
    if config.disaggregation_mode == DisaggregationMode.AGGREGATED:
        return True

    # Prefill-first chain: the prefill worker leads.
    if (
        config.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
        and config.disaggregation_mode == DisaggregationMode.PREFILL
    ):
        return True

    # Decode-first chain: the decode worker leads; anything else is not first.
    return (
        config.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST
        and config.disaggregation_mode == DisaggregationMode.DECODE
    )


def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
    """Split a Dynamo endpoint string into its three components.

    Accepts ``dyn://namespace.component.endpoint`` or the same string without
    the ``dyn://`` scheme.

    Args:
        endpoint: The endpoint string to parse.

    Returns:
        A ``(namespace, component, endpoint_name)`` tuple.

    Raises:
        ValueError: If the string does not consist of exactly three non-empty,
            dot-separated parts, or still contains a ``://`` after the
            optional leading ``dyn://`` has been removed.
    """
    # Strip the scheme only when it is a real prefix. ``str.replace`` would
    # also remove a "dyn://" buried mid-string and so accept malformed input
    # such as "a.dyn://b.c".
    endpoint_str = endpoint.removeprefix("dyn://")
    endpoint_parts = endpoint_str.split(".")

    is_well_formed = (
        len(endpoint_parts) == 3
        # Reject empty segments such as "dyn://a..c" or "dyn://..".
        and all(endpoint_parts)
        # Reject a repeated or foreign scheme left inside a segment.
        and "://" not in endpoint_str
    )
    if not is_well_formed:
        raise ValueError(
            f"Invalid endpoint format: '{endpoint}'. "
            "Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )

    namespace, component, endpoint_name = endpoint_parts
    return namespace, component, endpoint_name


def cmd_line_args(argv: Optional[list[str]] = None) -> "Config":
    """Parse command-line flags and resolve them into a ``Config``.

    Endpoint defaults depend on the worker's position in the chain (see
    ``is_first_worker``):

    * first worker: serves ``DEFAULT_ENDPOINT``; when the mode is not
      AGGREGATED it forwards to ``DEFAULT_NEXT_ENDPOINT`` unless
      ``--next-endpoint`` is given.
    * any other worker: serves ``DEFAULT_NEXT_ENDPOINT`` and must not be
      given ``--next-endpoint``.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        The populated ``Config``. ``namespace``, ``component`` and
        ``endpoint`` hold the three parts of the resolved endpoint string;
        ``next_endpoint`` keeps the full, unparsed string.

    Raises:
        ValueError: If ``--next-endpoint`` is supplied for a worker that is
            not first in the chain, or if ``parse_endpoint`` rejects the
            endpoint string.
        SystemExit: Raised by argparse for unknown/invalid flags or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="TensorRT-LLM server integrated with Dynamo LLM."
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default="",
        help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} if first worker, {DEFAULT_NEXT_ENDPOINT} if next worker",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL_PATH}",
    )
    parser.add_argument(
        "--served-model-name",
        type=str,
        default="",
        help="Name to serve the model under. Defaults to deriving it from model path.",
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
    )
    # IMPORTANT: We should ideally not expose this to users. We should be able to
    # query the block size from the TRTLLM engine.
    parser.add_argument(
        "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
    )

    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
    )
    parser.add_argument(
        "--publish-events-and-metrics",
        action="store_true",
        help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
    )
    # The default and the help text use the enum's ``.value`` so that they
    # match the strings listed in ``choices``; formatting the enum member
    # itself can render as e.g. "DisaggregationMode.AGGREGATED", which is not
    # something the user can actually pass.
    parser.add_argument(
        "--disaggregation-mode",
        type=str,
        default=DEFAULT_DISAGGREGATION_MODE.value,
        choices=[mode.value for mode in DisaggregationMode],
        help=f"Mode to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_MODE.value}",
    )
    parser.add_argument(
        "--disaggregation-strategy",
        type=str,
        default=DEFAULT_DISAGGREGATION_STRATEGY.value,
        choices=[strategy.value for strategy in DisaggregationStrategy],
        help=f"Strategy to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_STRATEGY.value}",
    )
    parser.add_argument(
        "--next-endpoint",
        type=str,
        default="",
        help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
    )
    args = parser.parse_args(argv)

    config = Config()

    # Model path and served model name. An empty name is stored as None
    # (this becomes an `Option` on the Rust side).
    config.model_path = args.model_path
    config.served_model_name = args.served_model_name or None

    # Disaggregation mode and strategy: convert the CLI strings to enum members.
    config.disaggregation_mode = DisaggregationMode(args.disaggregation_mode)
    config.disaggregation_strategy = DisaggregationStrategy(
        args.disaggregation_strategy
    )

    # Resolve endpoint defaults on local copies rather than mutating ``args``.
    endpoint = args.endpoint
    next_endpoint = args.next_endpoint
    if is_first_worker(config):
        if not endpoint:
            endpoint = DEFAULT_ENDPOINT
        # Only a disaggregated first worker has a following worker to call.
        if (
            not next_endpoint
            and config.disaggregation_mode != DisaggregationMode.AGGREGATED
        ):
            next_endpoint = DEFAULT_NEXT_ENDPOINT
    else:
        if not endpoint:
            endpoint = DEFAULT_NEXT_ENDPOINT
        # A worker that is not first has nothing after it to forward to.
        if next_endpoint:
            raise ValueError("Next endpoint is not allowed for the next worker")

    config.namespace, config.component, config.endpoint = parse_endpoint(endpoint)
    config.next_endpoint = next_endpoint

    config.tensor_parallel_size = args.tensor_parallel_size
    config.kv_block_size = args.kv_block_size
    config.extra_engine_args = args.extra_engine_args
    config.publish_events_and_metrics = args.publish_events_and_metrics

    return config