io_struct.py 10.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
18
19
20
"""
The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

21
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import uuid
23
from dataclasses import dataclass, field
Lianmin Zheng's avatar
Lianmin Zheng committed
24
25
from typing import Dict, List, Optional, Union

26
from sglang.srt.managers.schedule_batch import BaseFinishReason
27
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
28
29
30
31


@dataclass
class GenerateReqInput:
    """The input of a generation request, covering a single prompt or a batch.

    After construction, callers invoke ``post_init`` to validate the request
    and normalize every per-request field into either a scalar (single
    request) or a list with one entry per expanded request (batched request).
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Union[List[Dict], Dict] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False

    def post_init(self):
        """Validate and normalize the request fields.

        Sets ``self.is_single``; for batched requests additionally sets
        ``self.batch_size`` and ``self.parallel_sample_num`` and expands all
        scalar fields into per-request lists.

        Raises:
            ValueError: if neither or both of text/input_ids are provided,
                if batched sampling params disagree on "n" (> 1), or if rid
                is given but is not a list for a batched request.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # A request with parallel sampling (n > 1) is expanded into multiple
        # sub-requests, so it goes through the batched path even when the
        # prompt itself is a single string.
        if (
            isinstance(self.sampling_params, dict)
            and self.sampling_params.get("n", 1) != 1
        ):
            is_single = False
        else:
            if self.text is not None:
                is_single = isinstance(self.text, str)
            else:
                is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            # Determine the parallel sampling factor ("n").
            if isinstance(self.sampling_params, dict):
                parallel_sample_num = self.sampling_params.get("n", 1)
            elif isinstance(self.sampling_params, list):
                parallel_sample_num_list = [
                    sp.get("n", 1) for sp in self.sampling_params
                ]
                parallel_sample_num = max(parallel_sample_num_list)
                all_equal = all(
                    element == parallel_sample_num
                    for element in parallel_sample_num_list
                )
                if parallel_sample_num > 1 and (not all_equal):
                    # TODO cope with the case that the parallel_sample_num is different for different samples
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )
            else:
                parallel_sample_num = 1
            self.parallel_sample_num = parallel_sample_num

            if parallel_sample_num != 1:
                # parallel sampling +1 represents the original prefill stage
                num = parallel_sample_num + 1
                if isinstance(self.text, list):
                    # support batch operation
                    self.batch_size = len(self.text)
                    num = num * len(self.text)
                elif isinstance(self.input_ids, list) and isinstance(
                    self.input_ids[0], list
                ):
                    self.batch_size = len(self.input_ids)
                    num = num * len(self.input_ids)
                else:
                    self.batch_size = 1
            else:
                # support select operation
                num = len(self.text) if self.text is not None else len(self.input_ids)
                self.batch_size = num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                # Use independent dicts (not [{}] * num): multiplying a list of
                # one dict aliases a single shared dict, so mutating the params
                # of one request would silently affect all of them.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                # Broadcast a single dict into independent per-request copies
                # for the same aliasing reason as above.
                self.sampling_params = [
                    copy.deepcopy(self.sampling_params) for _ in range(num)
                ]

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num

Lianmin Zheng's avatar
Lianmin Zheng committed
155
156
157

@dataclass
class TokenizedGenerateReqInput:
    """A single generation request after tokenization.

    Unlike GenerateReqInput, every field here is a scalar for one request
    (batches are split before this point).
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The pixel values for input images
    pixel_values: List[float]
    # The hash values of input images
    image_hashes: List[int]
    # The image sizes
    image_sizes: List[List[int]]
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool


182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request, covering a single prompt or a batch.

    Callers invoke ``post_init`` to validate the request and normalize the
    fields; sampling params are dummies kept only for interface compatibility,
    pinned to ``max_new_tokens=1``.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Union[List[Dict], Dict] = None

    def post_init(self):
        """Validate and normalize the request fields.

        Sets ``self.is_single``; for batched requests also sets
        ``self.batch_size`` and expands rid/sampling_params into lists.

        Raises:
            ValueError: if neither or both of text/input_ids are provided,
                or if rid is given but is not a list for a batched request.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
            is_single = isinstance(self.text, str)
        else:
            is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            # support select operation
            self.batch_size = (
                len(self.text) if self.text is not None else len(self.input_ids)
            )
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")
            if self.sampling_params is None:
                # Use independent dicts (not [{}] * batch_size): the loop below
                # mutates each element, and list-multiplication would alias one
                # shared dict across every request.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif not isinstance(self.sampling_params, list):
                # Broadcast a single dict into independent per-request copies;
                # previously this case fell through and crashed on int-indexing
                # a dict in the loop below.
                self.sampling_params = [
                    copy.deepcopy(self.sampling_params)
                    for _ in range(self.batch_size)
                ]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1
225
226
227
228


@dataclass
class TokenizedEmbeddingReqInput:
    """A single embedding request after tokenization."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


Lianmin Zheng's avatar
Lianmin Zheng committed
239
240
@dataclass
class BatchTokenIDOut:
    """A batch of token-id outputs; each list holds one entry per request.

    Consumed by the detokenizer, which uses decode_ids/read_offsets for
    incremental detokenization (see the vids comment below).
    """

    # The request id
    rids: List[str]
    # The version id to sync decode status with in detokenizer_manager
    vids: List[int]
    # The text decoded so far for each request
    decoded_texts: List[str]
    # The token ids to detokenize and the current read offsets into them
    # (per request; exact layout is defined by the detokenizer side)
    decode_ids: List[int]
    read_offsets: List[int]
    # Detokenization option: whether to skip special tokens
    skip_special_tokens: List[bool]
    # Detokenization option: whether to add spaces between special tokens
    spaces_between_special_tokens: List[bool]
    # Per-request metadata dicts
    meta_info: List[Dict]
    # Per-request finish reason
    finished_reason: List[BaseFinishReason]

    def __post_init__(self):
        # deepcopy meta_info to avoid modification in place:
        # the sender may keep mutating its own meta dicts after this
        # object is constructed/sent.
        self.meta_info = copy.deepcopy(self.meta_info)

Liangsheng Yin's avatar
Liangsheng Yin committed
257

Lianmin Zheng's avatar
Lianmin Zheng committed
258
259
@dataclass
class BatchStrOut:
    """A batch of detokenized outputs; each list holds one entry per request."""

    # The request id
    rids: List[str]
    # The output decoded strings
    output_strs: List[str]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]
Liangsheng Yin's avatar
Liangsheng Yin committed
268
269


270
271
@dataclass
class BatchEmbeddingOut:
    """A batch of embedding outputs; each list holds one entry per request."""

    # The request id
    rids: List[str]
    # The output embedding
    embeddings: List[List[float]]
    # The meta info
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]


Liangsheng Yin's avatar
Liangsheng Yin committed
282
283
284
@dataclass
class FlushCacheReq:
    """Request to flush the cache; carries no payload."""

    pass
Cody Yu's avatar
Cody Yu committed
285

286

287
288
289
290
291
292
293
294
295
296
297
298
299
300
@dataclass
class UpdateWeightReqInput:
    """Request to reload model weights from a new checkpoint path."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
    """Result of a weight-update request (see UpdateWeightReqInput)."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status or error message
    message: str


301
302
@dataclass
class AbortReq:
    """Request to abort an in-flight request identified by rid."""

    # The request id
    rid: str