"src/targets/vscode:/vscode.git/clone" did not exist on "c27768c75aad8d215808a062703e6ebce07246da"
io_struct.py 11.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input. It can be a file name, a URL, or a base64-encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If returning logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is -1, which means logprobs are only returned for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    def post_init(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        self.is_single = False
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.batch_size = len(self.input_ids)

        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sp in self.sampling_params:
                # TODO: handle the case where parallel_sample_num differs across samples
                assert self.parallel_sample_num == sp.get(
                    "n", 1
                ), "The parallel_sample_num should be the same for all samples in the sampling params."

        if self.parallel_sample_num > 1:
            if self.is_single:
                self.is_single = False
                if self.text is not None:
                    self.text = [self.text]
                if self.input_ids is not None:
                    self.input_ids = [self.input_ids]

        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # FIXME: support cascade inference
                # The first batch_size entries are used for caching the prefix for parallel sampling
                num = self.batch_size + self.parallel_sample_num * self.batch_size

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num
            elif isinstance(self.image_data, list):
                # FIXME incorrect order for duplication
                self.image_data = self.image_data * num

            if self.sampling_params is None:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num
            else:
                assert self.parallel_sample_num == 1

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."
                assert self.parallel_sample_num == 1

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1
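
# Illustrative usage of GenerateReqInput (a minimal sketch, not part of the
# original module). With parallel sampling (n > 1), post_init expands the
# bookkeeping lists (rid, sampling_params, return_logprob, ...) to
# num = batch_size + n * batch_size entries; the first batch_size entries
# cache the shared prefix:
#
#   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
#   req.post_init()
#   assert req.is_single is False and req.text == ["Hello"]
#   assert len(req.rid) == 1 + 2 * 1  # num == 3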


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image input
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If returning logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
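
# A hedged sketch of how a TokenizerManager-style caller might build this
# message from a single, normalized GenerateReqInput (illustrative only; the
# real conversion lives in the tokenizer manager and also fills image inputs):
#
#   tokenized = TokenizedGenerateReqInput(
#       rid=req.rid,
#       input_text=req.text,
#       input_ids=tokenizer.encode(req.text),  # `tokenizer` is an assumed helper
#       image_inputs={},
#       sampling_params=SamplingParams(**req.sampling_params),
#       return_logprob=req.return_logprob,
#       logprob_start_len=req.logprob_start_len,
#       top_logprobs_num=req.top_logprobs_num,
#       stream=req.stream,
#   )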


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None

    def post_init(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
            self.is_single = isinstance(self.text, str)
        else:
            self.is_single = isinstance(self.input_ids[0], int)

        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            # support select operation
            self.batch_size = (
                len(self.text) if self.text is not None else len(self.input_ids)
            )
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")
            if self.sampling_params is None:
                self.sampling_params = [{}] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class RewardReqInput:
    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
    conv: Union[List[List[Dict]], List[Dict]]
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None

    def post_init(self):
        self.is_single = isinstance(self.conv[0], dict)

        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            # support select operation
            self.batch_size = len(self.conv)
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")
            if self.sampling_params is None:
                self.sampling_params = [{}] * self.batch_size
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1


@dataclass
class TokenizedRewardReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The version ids used to sync the decode status with the detokenizer_manager
    vids: List[int]
    # The decoded text so far, per request
    decoded_texts: List[str]
    # Incremental detokenization state: the accumulated token ids per request,
    # with read offsets marking what has already been emitted
    decode_ids: List[int]
    read_offsets: List[int]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    meta_info: List[Dict]
    finished_reason: List[BaseFinishReason]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The output decoded strings
    output_strs: List[str]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The output embedding
    embeddings: List[List[float]]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
    # Whether the update succeeded
    success: bool
    # The status message
    message: str
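
# Illustrative usage (a minimal sketch, not part of the original module; the
# checkpoint path is a placeholder):
#
#   req = UpdateWeightReqInput(model_path="/path/to/new/checkpoint")
#   # The serving process replies with an UpdateWeightReqOutput, e.g.
#   # UpdateWeightReqOutput(success=True, message="Weights updated.")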


@dataclass
class AbortReq:
    # The request id
    rid: str