# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
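
"""Disaggregated TensorRT-LLM worker for dynamo.

The leader rank of each MPI sub-communicator serves the `completions` and
`chat/completions` endpoints for either a context ("ctx", prefill) or a
generation ("gen", decode) server instance.
"""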

import asyncio
import json
import os
import signal

import uvloop
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.disagg_processor import ChatProcessor, parse_chat_message_content
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises
from common.protocol import (
    DisaggChatCompletionRequest,
    DisaggChatCompletionStreamResponse,
    DisaggCompletionStreamResponse,
    DisaggregatedTypeConverter,
)
from mpi4py.futures import MPICommExecutor
from mpi4py.MPI import COMM_WORLD
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import MpiCommSession
from tensorrt_llm.llmapi.disagg_utils import (
    CtxGenServerConfig,
    DisaggServerConfig,
    parse_disagg_config_file,
    split_world_comm,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest

from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker

logger.set_level("debug")


def update_args_from_disagg_config(
    engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
    # Overwrite the LLM API config with the disaggregated config; this allows
    # the context and generation servers to run with different configs.
    engine_config.extra_args.update(**server_config.other_args)
    engine_config.update_sub_configs(server_config.other_args)
    return engine_config


class TensorrtLLMEngine(BaseTensorrtLLMEngine):
    """
    Request handler for the disaggregated `completions` and `chat/completions`
    endpoints.
    """

    def __init__(
        self,
        trt_llm_engine_config: TensorrtLLMEngineConfig,
        disagg_config: DisaggServerConfig,
        instance_idx: int,
        sub_comm,
    ):
        self.disagg_config = disagg_config
        self.instance_idx = instance_idx
        self.server_config: CtxGenServerConfig = disagg_config.server_configs[
            instance_idx
        ]
        engine_config = update_args_from_disagg_config(
            trt_llm_engine_config.engine_config, self.server_config
        )
        trt_llm_engine_config.engine_config = engine_config

        # An MPI session over this instance's sub-communicator is required for
        # disaggregated serving; it is handed to the LLM API via extra_args.
        self._mpi_session = MpiCommSession(sub_comm, n_workers=sub_comm.Get_size())
        trt_llm_engine_config.engine_config.extra_args[
            "_mpi_session"
        ] = self._mpi_session

        super().__init__(trt_llm_engine_config)

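    # dynamo_endpoint registers this handler as a typed endpoint: the first
    # argument is the request schema, the second the streamed response schema.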
    @dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
    async def generate_chat(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

        # Check if there are any errors in the error queue.
        if self._error_queue.qsize() > 0:
            error = self._error_queue.get()
            raise error

        logger.debug(f"Received request: {request}")
        chat_processor = ChatProcessor(self._model, self._tokenizer, request)

        self._ongoing_request_count += 1

        try:
            conversation = []
            for message in request.messages:
                conversation.extend(parse_chat_message_content(message))
            tool_dicts = (
                None
                if request.tools is None
                else [tool.model_dump() for tool in request.tools]
            )
            prompt: str = self._tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()
            disaggregated_params = (
                DisaggregatedTypeConverter.to_llm_disaggregated_params(
                    request.disaggregated_params
                )
            )

            final_result = None
            async for result in self._llm_engine.generate_async(
                prompt,
                sampling_params,
                streaming=request.stream,
                disaggregated_params=disaggregated_params,
            ):
                final_result = result
                logger.debug(f"Generated result: {result}")
                if self.server_config.type == "ctx":
                    disaggregated_response = chat_processor.get_chat_stream_response(
                        request.id,
                        result,
                        first_iteration=True,
                    )
                    disaggregated_response.disaggregated_params = (
                        DisaggregatedTypeConverter.to_oai_disaggregated_params(
                            result.outputs[0].disaggregated_params
                        )
                    )
                    yield disaggregated_response.model_dump_json()
                else:
                    yield chat_processor.get_chat_stream_response(
                        request.id,
                        result,
                        first_iteration=False,
                    ).model_dump_json(
                        exclude_unset=True, exclude={"disaggregated_params"}
                    )

            if request.stream_options and request.stream_options.include_usage:
                yield chat_processor.create_final_stream_response(
                    request.id,
                    final_result,
                ).model_dump_json(exclude_unset=True, exclude={"disaggregated_params"})

        except CppExecutorError:
            # If an internal executor error is raised, shut down the server.
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            raise RuntimeError(f"Failed to generate: {e}") from e

        # Start the publishing threads with the first request submission.
        # NOTE: TRT-LLM needs the stats to be collected on the same loop as the
        # request handler, so record the running loop and hand it to the
        # publishing threads before starting them (mirroring generate_completions).
        self._stats_loop = asyncio.get_running_loop()
        if (
            self.publish_kv_cache_events_thread
            and not self.publish_kv_cache_events_thread.is_alive()
        ):
            self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
            self.publish_kv_cache_events_thread.start()

        if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
            self.publish_stats_thread.set_loop(self._stats_loop)
            self.publish_stats_thread.start()

        self._ongoing_request_count -= 1

    @dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
    async def generate_completions(self, request):
        logger.debug(f"[worker] worker_id: {self._worker_id} received request")
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

        # Check if there are any errors in the error queue.
        if self._error_queue.qsize() > 0:
            error = self._error_queue.get()
            raise error

        self._ongoing_request_count += 1
        logger.debug(f"[worker] Received completions request: {request}")

        if not isinstance(request.prompt, str):
            # Unwrap a single-element list; otherwise only a list of token ids
            # is accepted.
            if isinstance(request.prompt, list) and len(request.prompt) == 1:
                request.prompt = request.prompt[0]
            elif not isinstance(request.prompt, list) or not all(
                isinstance(x, int) for x in request.prompt
            ):
                raise ValueError(
                    "Disaggregated server currently only supports a single string prompt or a list of token ids per request"
                )

        sampling_params = request.to_sampling_params()
        llm_disaggregated_params = (
            DisaggregatedTypeConverter.to_llm_disaggregated_params(
                request.disaggregated_params
            )
        )

        # Only a single prompt is supported for now.
        promise = self._llm_engine.generate_async(
            request.prompt,
            sampling_params,
            streaming=request.stream,
            disaggregated_params=llm_disaggregated_params,
        )
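        # merge_promises multiplexes a list of per-prompt result streams into
        # a single async generator; only one promise is passed here.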
        generator = merge_promises([promise])
        num_choices = 1 if request.n is None else request.n
        if request.stream:
            response_generator = self.completions_processor.create_completion_generator(
                request, generator, num_choices
            )
            async for response in response_generator:
                yield json.loads(response)
        else:
            raise RuntimeError("Non-streaming is not supported")

        # Start the publishing threads with the first request submission.
        if (
            self.publish_kv_cache_events_thread
            and not self.publish_kv_cache_events_thread.is_alive()
        ):
            # NOTE: TRT-LLM needs the stats to be collected on the same loop
            # as the request handler.
            self._stats_loop = asyncio.get_running_loop()
            self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
            self.publish_kv_cache_events_thread.start()

        if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
            self._stats_loop = asyncio.get_running_loop()
            self.publish_stats_thread.set_loop(self._stats_loop)
            self.publish_stats_thread.start()

        self._ongoing_request_count -= 1


@dynamo_worker()
async def worker(
    runtime: DistributedRuntime,
    engine_config: LLMAPIConfig,
    disagg_config: DisaggServerConfig,
    instance_idx: int,
    sub_comm,
    publish_stats: bool,
    publish_kv_cache_events: bool,
):
    """
    Instantiate a worker component for this server type and serve its
    `completions` and `chat/completions` endpoints. A `Component` can serve
    multiple endpoints.
    """
    server_type = disagg_config.server_configs[instance_idx].type
    logger.info(f"Starting {server_type} server")

    namespace_str = "dynamo"
    component_str = f"tensorrt-llm-{server_type}"

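    # Each server type registers its own component under the shared namespace,
    # so context and generation workers can be addressed separately.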
    component = runtime.namespace(namespace_str).component(component_str)
    await component.create_service()

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")

    if server_type == "gen":
        if publish_stats:
            logger.warning("Stats can only be published for ctx server")
            publish_stats = False
        if publish_kv_cache_events:
            logger.warning("KV cache events can only be published for ctx server")
            publish_kv_cache_events = False

    trt_llm_engine_config = TensorrtLLMEngineConfig(
        namespace_str=namespace_str,
        component_str=component_str,
        engine_config=engine_config,
        publish_stats=publish_stats,
        publish_kv_cache_events=publish_kv_cache_events,
    )

    # NOTE: The current implementation adds two endpoints; it could be
    # refactored to expose a single endpoint that handles both completions and
    # chat. The completions endpoint's lease id is currently used as the worker
    # id, which might cause issues when smart routing is used with the chat
    # completions endpoint.
    trt_llm_engine_config.worker_id = completions_endpoint.lease_id()

    if publish_stats:
        trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()

    engine = TensorrtLLMEngine(
        trt_llm_engine_config,
        disagg_config,
        instance_idx,
        sub_comm,
    )

    coros = [
        completions_endpoint.serve_endpoint(engine.generate_completions),
        chat_endpoint.serve_endpoint(engine.generate_chat),
    ]
    if publish_stats:
        coros.append(
            trt_llm_engine_config.kv_metrics_publisher.create_endpoint(component)
        )

    await asyncio.gather(*coros)


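# Example launch (illustrative; the exact flag names depend on
# parse_tensorrt_llm_args, which supplies args.llmapi_disaggregated_config,
# args.publish_stats, and args.publish_kv_cache_events):
#   mpirun -n <world_size> python worker.py \
#       --llmapi-disaggregated-config disagg_config.yaml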
if __name__ == "__main__":
    uvloop.install()
    args, engine_config = parse_tensorrt_llm_args()

    if args.llmapi_disaggregated_config is None or not os.path.exists(
        args.llmapi_disaggregated_config
    ):
        raise ValueError(
            "llmapi_disaggregated_config file was not provided or does not exist"
        )

    disagg_config: DisaggServerConfig = parse_disagg_config_file(
        args.llmapi_disaggregated_config
    )

    logger.info(f"Parsed disaggregated config: {disagg_config}")

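    # Partition MPI_COMM_WORLD into one sub-communicator per ctx/gen server
    # instance; only the leader rank of each instance runs the dynamo worker,
    # while the remaining ranks act as MPI workers for the engine.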
    is_leader, instance_idx, sub_comm = split_world_comm(disagg_config.server_configs)
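    # Use MPI for KV-cache transfer between the context and generation instances.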
    os.environ["TRTLLM_USE_MPI_KVCACHE"] = "1"
    set_mpi_comm(sub_comm)

    logger.info(f"is_leader: {is_leader}, instance_idx: {instance_idx}")

    if is_leader:
        asyncio.run(
            worker(
                engine_config,
                disagg_config,
                instance_idx,
                sub_comm,
                args.publish_stats,
                args.publish_kv_cache_events,
            )
        )
    else:
        with MPICommExecutor(sub_comm) as executor:
            # Non-leader ranks block here as MPI executor workers; only the
            # root rank of the sub-communicator receives an executor object.
            if executor is not None:
                raise RuntimeError(
                    f"rank {COMM_WORLD.Get_rank()} should not have an executor"
                )