io_struct.py 12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
18
19
20
"""
The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
21
import uuid
22
from dataclasses import dataclass
23
from enum import Enum
Lianmin Zheng's avatar
Lianmin Zheng committed
24
25
from typing import Dict, List, Optional, Union

26
from sglang.srt.managers.schedule_batch import BaseFinishReason
27
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
28
29
30
31


@dataclass
class GenerateReqInput:
    """The input of a /generate request: a single prompt or a batch of prompts.

    Every per-request field may be given either as a scalar (single request)
    or as a list (batch). Call `normalize_batch_and_arguments` exactly once
    before use: it validates the input, derives `is_single` / `batch_size` /
    `parallel_sample_num`, and expands scalar fields to per-request lists for
    batches so that `__getitem__` can slice out individual requests.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related. None means just use the base model.
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    def normalize_batch_and_arguments(self):
        """Validate the request and normalize all per-request arguments.

        Raises:
            ValueError: if neither or both of `text` and `input_ids` are given.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list)
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num: each prompt yields n sub-requests.
                num = self.batch_size * self.parallel_sample_num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                # One independent dict per request. "[{}] * num" would alias a
                # single shared dict, so mutating one request's params would
                # silently mutate them all.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                # A caller-provided list cannot be expanded for parallel sampling.
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

            # Fix: expand a single lora_path to a per-request list; otherwise
            # __getitem__ would index into the *string* and return one character.
            if self.lora_path is not None and not isinstance(self.lora_path, list):
                self.lora_path = [self.lora_path] * num

    def regenerate_rid(self):
        """Assign a fresh request id (single request) and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a new single-request input."""
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
176
177
178

@dataclass
class TokenizedGenerateReqInput:
    """The tokenized counterpart of a single GenerateReqInput item.

    All fields are per-request scalars (the batch has already been split).
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image input
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool
    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model

Lianmin Zheng's avatar
Lianmin Zheng committed
201

202
203
204
205
206
207
208
209
210
211
212
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request: a single prompt or a batch of prompts.

    Call `normalize_batch_and_arguments` exactly once before use: it validates
    the input, derives `is_single` / `batch_size`, fills default `rid` /
    `sampling_params`, and pins `max_new_tokens` to 1 (embedding requests do
    not generate text; the sampling params exist only for compatibility).
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility.
    # Fix: annotated Optional to match the None default.
    sampling_params: Optional[Union[List[Dict], Dict]] = None

    def normalize_batch_and_arguments(self):
        """Validate the request and normalize per-request arguments.

        Raises:
            ValueError: if neither or both of `text` and `input_ids` are given.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # One independent dict per request. "[{}] * n" would alias a
                # single shared dict across the whole batch.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif not isinstance(self.sampling_params, list):
                # Fix: a single dict used to crash below (dict indexed by int).
                # Expand it with independent copies over the batch.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1

    def regenerate_rid(self):
        """Assign a fresh request id (single request) and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a new single-request input."""
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
264
265
266


@dataclass
class TokenizedEmbeddingReqInput:
    """The tokenized counterpart of a single EmbeddingReqInput item."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
278
279
@dataclass
class BatchTokenIDOut:
    """Batched per-request output at the token-id level.

    Parallel lists: index i of every field belongs to request `rids[i]`.
    """

    # The request id
    rids: List[str]
    # The version id to sync decode status with in detokenizer_manager
    vids: List[int]
    # The text decoded so far for each request
    decoded_texts: List[str]
    # Token ids and read offsets for incremental detokenization
    # NOTE(review): annotated List[int], but consumed per-request alongside
    # read_offsets — confirm whether this is actually List[List[int]].
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init`
    output_ids: Optional[List[int]]
    # Per-request detokenization options
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    # Per-request meta info
    meta_info: List[Dict]
    # Per-request finish reason (None-like entries for unfinished requests —
    # TODO confirm against the producer)
    finished_reason: List[BaseFinishReason]
    # Per-request flag controlling stop-sequence trimming
    no_stop_trim: List[bool]
Lianmin Zheng's avatar
Lianmin Zheng committed
294

Liangsheng Yin's avatar
Liangsheng Yin committed
295

Lianmin Zheng's avatar
Lianmin Zheng committed
296
297
@dataclass
class BatchStrOut:
    """Batched per-request output at the string level.

    Parallel lists: index i of every field belongs to request `rids[i]`.
    """

    # The request id
    rids: List[str]
    # The output decoded strings
    output_strs: List[str]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]
Liangsheng Yin's avatar
Liangsheng Yin committed
306
307


308
309
@dataclass
class BatchEmbeddingOut:
    """Batched per-request embedding output.

    Parallel lists: index i of every field belongs to request `rids[i]`.
    """

    # The request id
    rids: List[str]
    # The output embedding
    embeddings: List[List[float]]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]


Liangsheng Yin's avatar
Liangsheng Yin committed
320
321
322
@dataclass
class FlushCacheReq:
    """Control message requesting a cache flush (carries no payload)."""

    pass
Cody Yu's avatar
Cody Yu committed
323

324

325
326
327
328
329
330
331
332
333
334
335
336
337
338
@dataclass
class UpdateWeightReqInput:
    """Request to reload the model weights from a new path."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
    """Result of an UpdateWeightReqInput request."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status or error message
    message: str


339
340
@dataclass
class AbortReq:
    """Request to abort an in-flight request by id."""

    # The request id
    rid: str
343
344
345
346
347


class ProfileReq(Enum):
    """Control message to start or stop profiling.

    Values are fixed integers because they cross process boundaries.
    """

    START_PROFILE = 1
    STOP_PROFILE = 2
348
349
350
351
352
353
354
355
356
357


@dataclass
class GetMemPoolSizeReq:
    """Request to query the memory pool size (carries no payload)."""

    pass


@dataclass
class GetMemPoolSizeReqOutput:
    """Response to GetMemPoolSizeReq carrying the memory pool size."""

    # The memory pool size
    size: int