router.py 6.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import copy
18
import json
19
20

import uvloop
21
22
23
24
from common.protocol import (
    DisaggChatCompletionRequest,
    DisaggChatCompletionStreamResponse,
    DisaggCompletionStreamResponse,
25
26
)
from tensorrt_llm.logger import logger
27
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
28

Neelay Shah's avatar
Neelay Shah committed
29
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
30

31
# Module-level side effect: switch the imported TensorRT-LLM logger to debug
# verbosity as soon as this module is loaded.
# NOTE(review): the level is hard-coded; confirm this is intended outside of
# development runs.
logger.set_level("debug")
32
33
34


class Router:
    """Two-phase (context, then generation) request router.

    Every incoming request is first sent to a context client with
    ``request_type="context_only"``; the disaggregated params found in that
    response are attached to a copy of the request, which is then sent to a
    generation client with ``request_type="generation_only"``. Generation
    responses are streamed back to the caller as plain dicts.
    """

    def __init__(
        self,
        ctx_chat_client,
        gen_chat_client,
        ctx_completion_client,
        gen_completion_client,
    ):
        """Store the four endpoint clients used by the two generate methods.

        Args:
            ctx_chat_client: client for the context-phase chat endpoint.
            gen_chat_client: client for the generation-phase chat endpoint.
            ctx_completion_client: client for the context-phase completions
                endpoint.
            gen_completion_client: client for the generation-phase
                completions endpoint.
        """
        self.ctx_chat_client = ctx_chat_client
        self.gen_chat_client = gen_chat_client
        self.ctx_completion_client = ctx_completion_client
        self.gen_completion_client = gen_completion_client
        logger.info("INITIALIZED ROUTER")

    async def _get_ctx_resp(self, request, ctx_client):
        """Run the context-only phase and return the single response payload.

        ``request`` is mutated in place: ``max_completion_tokens`` is set to 1
        and ``disaggregated_params`` is replaced with a ``context_only``
        marker before the request is serialized and sent.

        Args:
            request: the request model to send; must provide
                ``model_dump_json()``.
            ctx_client: client whose ``round_robin`` call yields the context
                server's responses.

        Returns:
            Whatever ``.data()`` returns for the one response received.

        Raises:
            ValueError: if the context server returns no response, or more
                than one response.
        """
        logger.debug(f"Received request {request}")

        # NOTE(review): this attribute is set for both chat and completion
        # requests -- confirm that CompletionRequest accepts it.
        request.max_completion_tokens = 1
        request.disaggregated_params = DisaggregatedParams(request_type="context_only")
        logger.debug(f"[router] Sending request to context server: {request}")

        ctx_resp = [
            resp
            async for resp in await ctx_client.round_robin(request.model_dump_json())
        ]
        # Exactly one response is required; an empty list would otherwise
        # surface as an opaque IndexError on ctx_resp[0] below.
        if not ctx_resp:
            raise ValueError("Context server returned no response.")
        if len(ctx_resp) > 1:
            raise ValueError(
                "Context server returned more than one response. This is currently not supported in disaggregated server."
            )

        ctx_data = ctx_resp[0].data()
        logger.debug(f"[router] received response from context server: {ctx_data}")
        return ctx_data

    # TODO (shreyasm): The only reason we can't further combine the two methods
    # below is that the disagg params are in different locations.
    # Disagg params should be under the choices field in the response object.
    # This is the case for completions but not for chat.

    @dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
    async def generate_completion(self, request):
        """Serve a completions request via the context and generation clients.

        Yields:
            dict: when ``request.stream`` is truthy, first the context
            response; then one dict per generation-server response, each
            dumped with ``exclude_unset=True``.

        Raises:
            ValueError: propagated from ``_get_ctx_resp`` when the context
                server does not return exactly one response.
        """
        # These settings are needed to satisfy request checks.
        request.skip_special_tokens = False
        request.add_special_tokens = False
        request.spaces_between_special_tokens = False

        # Copy before the context phase, which mutates ``request`` in place.
        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client)
        # NOTE(review): this path uses model_validate while generate_chat uses
        # model_validate_json -- confirm the payload type of each endpoint.
        ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)

        # For completions the disaggregated params live on the first choice.
        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
            ctx_resp_obj.choices[0].disaggregated_params
        )
        gen_req.disaggregated_params.request_type = "generation_only"

        if request.stream:
            # NOTE(review): this excludes only a top-level
            # "disaggregated_params" field, whereas the params above were read
            # from choices[0] -- confirm whether the nested value should be
            # stripped as well.
            yield json.loads(
                ctx_resp_obj.model_dump_json(
                    exclude_unset=True, exclude={"disaggregated_params"}
                )
            )

        logger.debug(f"[router] Sending request to generation server: {gen_req}")
        async for response in await self.gen_completion_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_data = response.data()
            logger.debug(
                f"[router] Received response from generation server: {gen_data}"
            )
            gen_resp_obj = DisaggCompletionStreamResponse.model_validate(gen_data)
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))

    @dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
    async def generate_chat(self, request):
        """Serve a chat-completions request via the context and generation clients.

        Yields:
            dict: when ``request.stream`` is truthy, first the context
            response with its top-level ``disaggregated_params`` excluded;
            then one dict per generation-server response, each dumped with
            ``exclude_unset=True``.

        Raises:
            ValueError: propagated from ``_get_ctx_resp`` when the context
                server does not return exactly one response.
        """
        # These settings are needed to satisfy request checks.
        request.skip_special_tokens = False
        request.add_special_tokens = False
        request.spaces_between_special_tokens = False

        # Copy before the context phase, which mutates ``request`` in place.
        gen_req = copy.deepcopy(request)

        ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client)
        ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)

        # For chat the disaggregated params sit at the top level of the response.
        gen_req.disaggregated_params = DisaggregatedParams.model_validate(
            ctx_resp_obj.disaggregated_params
        )
        gen_req.disaggregated_params.request_type = "generation_only"

        if request.stream:
            yield json.loads(
                ctx_resp_obj.model_dump_json(
                    exclude_unset=True, exclude={"disaggregated_params"}
                )
            )

        logger.debug(f"[router] Sending request to generation server: {gen_req}")
        async for response in await self.gen_chat_client.round_robin(
            gen_req.model_dump_json()
        ):
            gen_data = response.data()
            logger.debug(
                f"[router] Received response from generation server: {gen_data}"
            )
            gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
                gen_data
            )
            yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
143
144


Neelay Shah's avatar
Neelay Shah committed
145
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """
    Instantiate the `router` component and serve its `completions` and
    `chat/completions` endpoints.

    All names are resolved in the `dynamo` namespace. Clients are created for
    the `completions` and `chat/completions` endpoints of both the
    `tensorrt-llm-ctx` and `tensorrt-llm-gen` components and handed to a
    `Router`, whose two generate methods are then served concurrently.
    """
    component = runtime.namespace("dynamo").component("router")
    await component.create_service()

    ctx_completion_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-ctx")
        .endpoint("completions")
        .client()
    )
    gen_completion_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-gen")
        .endpoint("completions")
        .client()
    )
    ctx_chat_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-ctx")
        .endpoint("chat/completions")
        .client()
    )
    gen_chat_client = (
        await runtime.namespace("dynamo")
        .component("tensorrt-llm-gen")
        .endpoint("chat/completions")
        .client()
    )

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")
    router = Router(
        ctx_chat_client, gen_chat_client, ctx_completion_client, gen_completion_client
    )
    # Serve both endpoints concurrently; gather returns only once both
    # serve_endpoint awaitables have finished.
    await asyncio.gather(
        completions_endpoint.serve_endpoint(router.generate_completion),
        chat_endpoint.serve_endpoint(router.generate_chat),
    )
188
189
190
191


if __name__ == "__main__":
    # Install uvloop before asyncio.run() so the event loop it creates is a
    # uvloop loop.
    uvloop.install()
    # worker is called without arguments here; presumably the dynamo_worker
    # decorator supplies the runtime -- confirm against the dynamo docs.
    asyncio.run(worker())