worker.py 8.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The SGLang disaggregated serving flow is:

Processor -> PrefillWorker -> DecodeWorker

This is different from how we've implemented the vLLM disaggregated flow.

For now, SGLangWorker is responsible for both aggregated serving and the prefill
side of disaggregated serving, and a separate DecodeWorker handles decode.
"""

27
import asyncio
28
import logging
29
30
import random
import socket
31
from typing import Dict, Union
32
33

import sglang as sgl
34
35
36
from components.decode_worker import SGLangDecodeWorker
from sglang.srt.utils import get_ip
from utils.protocol import DisaggPreprocessedRequest, PreprocessedRequest
37
38
39
from utils.sglang import parse_sglang_args

from dynamo.llm import ModelType, register_llm
40
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
41
42
43
44
45
46
47
48
49
50
51
52

# Module-scoped logger used by everything in this file.
logger = logging.getLogger(name=__name__)


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"gpu": 1},
    workers=1,
)
class SGLangWorker:
    """SGLang worker for aggregated serving and the prefill side of
    disaggregated serving.

    In disaggregated mode each request is prefilled by the local ``self.engine``
    while the same request, plus bootstrap routing info, is forwarded to the
    ``SGLangDecodeWorker`` ``generate`` endpoint. The decode stream is what is
    returned to the caller; the local prefill output is discarded.
    """

    decode_worker = depends(SGLangDecodeWorker)

    def __init__(self):
        class_name = self.__class__.__name__
        self.engine_args = parse_sglang_args(class_name, "")
        self.engine = sgl.Engine(server_args=self.engine_args)

        logger.info("SGLangWorker initialized")

    def _is_disaggregated(self) -> bool:
        """Return True when the engine is configured for disaggregated serving.

        ``"null"`` is the value this worker treats as "aggregated". An unset or
        empty value is treated the same way. ``async_init`` and ``generate``
        both use this check, so they always agree on the serving mode.
        """
        mode = self.engine_args.disaggregation_mode
        return bool(mode) and mode != "null"

    @async_on_start
    async def async_init(self):
        """Register the ``generate`` endpoint for model discovery.

        In disaggregated mode, also resolve this engine's bootstrap address and
        create a client for the decode worker's ``generate`` endpoint.
        """
        runtime = dynamo_context["runtime"]
        logger.info("Registering LLM for discovery")
        comp_ns, comp_name = SGLangWorker.dynamo_address()  # type: ignore
        # Named so it does not shadow the module-level ``endpoint`` decorator.
        generate_endpoint = (
            runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
        )
        await register_llm(
            ModelType.Backend,
            generate_endpoint,
            self.engine_args.model_path,
            self.engine_args.served_model_name,
        )

        # Use the same mode test as ``generate``. "null" is a truthy string, so
        # a bare truthiness check would wrongly take this branch in aggregated
        # mode.
        if self._is_disaggregated():
            self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
            comp_ns, comp_name = SGLangDecodeWorker.dynamo_address()  # type: ignore
            self.decode_client = (
                await runtime.namespace(comp_ns)
                .component(comp_name)
                .endpoint("generate")
                .client()
            )

    def _get_bootstrap_info(self):
        """Return ``(host, port)`` of this engine's disaggregation bootstrap
        server, which is sent to the decode worker with every request.

        The port comes from the tokenizer manager's server args. For
        multi-node setups (``dist_init_addr`` set), the host part of that
        address is resolved to an IP. Otherwise the local IP from ``get_ip()``
        is used.
        """
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        # multinode check
        if inner_tm.server_args.dist_init_addr:
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port

    def _build_sampling_params(self, request: PreprocessedRequest) -> dict:
        """Translate the request's sampling and stop options into the
        ``sampling_params`` dict passed to SGLang.

        Optional options are forwarded only when supplied. ``max_new_tokens``
        is always set from ``stop_conditions.max_tokens``.
        """
        # TODO: maintain a full mapping from PreprocessedRequest to SGLang's SamplingParams
        opts = request.sampling_options
        sampling_params = {}
        # ``is not None`` rather than truthiness: a temperature of 0.0 (greedy
        # decoding) is a meaningful request and must be forwarded, not dropped.
        if opts.temperature is not None:
            sampling_params["temperature"] = opts.temperature
        # top_p / top_k keep the truthiness test: 0 is treated as "not set".
        if opts.top_p:
            sampling_params["top_p"] = opts.top_p
        if opts.top_k:
            sampling_params["top_k"] = opts.top_k
        sampling_params["max_new_tokens"] = request.stop_conditions.max_tokens
        if request.stop_conditions.ignore_eos:
            sampling_params["ignore_eos"] = request.stop_conditions.ignore_eos
        return sampling_params

    def _get_request_batch_size(self, request: PreprocessedRequest):
        """Get batch size from request, returns None for single requests"""
        if request.batch_token_ids is not None:
            return len(request.batch_token_ids)
        return None

    def _is_batch_request(self, request: PreprocessedRequest):
        """Check if request is in batch mode"""
        return request.batch_token_ids is not None

    @endpoint()
    async def generate(self, request: PreprocessedRequest):
        """Stream generated tokens for ``request``.

        Yields the chunks produced by ``_process_stream``. In aggregated mode
        they come from the local engine. In disaggregated mode they come from
        the decode worker, while the local engine performs the prefill.
        """
        is_batch = self._is_batch_request(request)
        batch_size = self._get_request_batch_size(request)
        input_ids = request.batch_token_ids if is_batch else request.token_ids

        # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
        sampling_params = self._build_sampling_params(request)

        if not self._is_disaggregated():
            g = await self.engine.async_generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                stream=True,
            )
            async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
                yield out
            return

        # One bootstrap room per sequence; the same room id is passed to both
        # the local prefill and the decode request.
        if is_batch:
            bootstrap_room = [
                self._generate_bootstrap_room() for _ in range(batch_size)
            ]
            bootstrap_host = [self.bootstrap_host] * batch_size
            bootstrap_port = [self.bootstrap_port] * batch_size
        else:
            bootstrap_host = self.bootstrap_host
            bootstrap_port = self.bootstrap_port
            bootstrap_room = self._generate_bootstrap_room()

        # decode worker request
        disagg_request = DisaggPreprocessedRequest(
            request=request,
            sampling_params=sampling_params,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

        # prefill response is not used; it is drained in a background task
        prefill = await self.engine.async_generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            stream=True,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
        )
        prefill_task = asyncio.create_task(self._prefill_generator(prefill))

        completed = False
        try:
            decode = await self.decode_client.generate(
                disagg_request.model_dump_json()
            )
            async for out in self._process_stream(
                decode, unpack=True, is_batch=is_batch
            ):
                yield out
            completed = True
        finally:
            # If the decode call failed or the consumer stopped iterating, do
            # not leave the prefill task orphaned. NOTE(review): whether SGLang
            # also aborts the in-flight request on cancellation is
            # engine-dependent; TODO confirm.
            if not completed:
                prefill_task.cancel()

        await prefill_task

    async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
        """Convert an SGLang output stream into incremental token chunks.

        Args:
            stream_source: async iterable of outputs. When ``unpack`` is True,
                each item is a response whose ``.data()`` returns the output
                dict (decode-client path). Otherwise, items are the output dicts
                themselves (local-engine path).
            unpack: whether to call ``.data()`` on each item.
            is_batch: track emitted-token offsets per ``index`` instead of
                with a single counter.

        Yields:
            ``{"token_ids": [...]}`` containing only tokens not yet emitted, or
            ``{"token_ids": [], "finish_reason": ...}`` when the output reports
            a finish reason. Batch chunks also carry ``"index"``.
        """
        # Initialize based on batch mode
        num_output_tokens_so_far: Union[Dict[int, int], int]
        if is_batch:
            num_output_tokens_so_far = {}
        else:
            num_output_tokens_so_far = 0

        async for res in stream_source:
            data = res.data() if unpack else res
            finish_reason = data["meta_info"]["finish_reason"]
            # NOTE(review): a chunk carrying finish_reason emits no tokens, so
            # any output_ids it adds beyond the last offset are not forwarded.
            # Confirm SGLang never delivers new tokens in its final chunk.

            if is_batch:
                # Handle batch response
                assert isinstance(num_output_tokens_so_far, dict)
                index = data.get("index", 0)
                if index not in num_output_tokens_so_far:
                    num_output_tokens_so_far[index] = 0

                if finish_reason:
                    out = {
                        "token_ids": [],
                        "finish_reason": finish_reason["type"],
                        "index": index,
                    }
                else:
                    # output_ids is cumulative; slice off what was already sent.
                    next_total_toks = len(data["output_ids"])
                    new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
                    out = {
                        "token_ids": new_tokens,
                        "index": index,
                    }
                    num_output_tokens_so_far[index] = next_total_toks
            else:
                # Handle single response
                assert isinstance(num_output_tokens_so_far, int)
                if finish_reason:
                    out = {"token_ids": [], "finish_reason": finish_reason["type"]}
                else:
                    # output_ids is cumulative; slice off what was already sent.
                    next_total_toks = len(data["output_ids"])
                    out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
                    num_output_tokens_so_far = next_total_toks

            yield out

    def _generate_bootstrap_room(self):
        """Return a random integer in ``[0, 2**63 - 1]`` used as the bootstrap
        room id that pairs a prefill with its decode request."""
        return random.randint(0, 2**63 - 1)

    async def _prefill_generator(self, prefill):
        """Drain the prefill stream, discarding every output."""
        async for _ in prefill:
            pass