# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
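
# Illustrative use of SessionParams for continual prompting (the id and rid
# values below are assumptions):
#     SessionParams(id="session-1", rid="rid-of-previous-turn")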


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The input embeddings; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    def normalize_batch_and_arguments(self):
        provided = [
            x for x in (self.text, self.input_ids, self.input_embeds) if x is not None
        ]
        if len(provided) != 1:
            raise ValueError(
                "Exactly one of text, input_ids, or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in the sampling params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
        )
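
# Example of batch normalization (illustrative; the prompts and the "n" value
# are assumptions):
#     req = GenerateReqInput(text=["Hi", "Hello"], sampling_params={"n": 2})
#     req.normalize_batch_and_arguments()
#     # Now req.batch_size == 2, req.parallel_sample_num == 2, and the
#     # list-valued fields (rid, sampling_params, image_data, ...) are
#     # broadcast to batch_size * parallel_sample_num == 4 entries.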


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None
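
# Illustrative construction (the field values are assumptions; in practice the
# TokenizerManager fills these in after tokenizing a GenerateReqInput):
#     TokenizedGenerateReqInput(
#         rid=uuid.uuid4().hex,
#         input_text="Hello",
#         input_ids=[9906],
#         image_inputs={},
#         sampling_params=SamplingParams(max_new_tokens=16),
#         return_logprob=False,
#         logprob_start_len=-1,
#         top_logprobs_num=0,
#         stream=False,
#     )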


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
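
# Example (illustrative): a single-prompt embedding request gains a fresh rid
# and a dummy sampling_params that forbids generation.
#     req = EmbeddingReqInput(text="hello world")
#     req.normalize_batch_and_arguments()
#     # req.is_single is True and req.sampling_params == {"max_new_tokens": 0}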


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    # The version ids used to sync decode status with the detokenizer_manager
    vids: List[int]
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]
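
# Sketch of the incremental detokenization these fields enable (an assumption
# about how DetokenizerManager consumes them; `tokenizer` stands in for an
# HF-style tokenizer):
#     read_text = tokenizer.decode(decode_ids[:read_offset])
#     full_text = tokenizer.decode(decode_ids)
#     new_text = full_text[len(read_text):]  # text produced past read_offset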


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromDistributedReqInput:
    name: str
    dtype: str
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    serialized_named_tensors: bytes  # a serialized Dict[str, torch.Tensor]
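
# A minimal sketch of building the payload (pickle is an assumption for
# illustration; the engine's actual wire format may differ):
#     import pickle, torch
#     payload = pickle.dumps({"lm_head.weight": torch.zeros(8, 8)})
#     UpdateWeightsFromTensorReqInput(serialized_named_tensors=payload)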


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class ReleaseMemoryOccupationReqInput:
    pass


@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass


@dataclass
class ResumeMemoryOccupationReqInput:
    pass


@dataclass
class ResumeMemoryOccupationReqOutput:
    pass


@dataclass
class AbortReq:
    # The request id
    rid: str


class ProfileReq(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool
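
# Session lifecycle sketch (the capacity value is an assumption): open a
# session, reuse the returned session_id in SessionParams for follow-up
# requests, then close it.
#     open_req = OpenSessionReqInput(capacity_of_str_len=4096)
#     # ... the server replies with an OpenSessionReqOutput ...
#     close_req = CloseSessionReqInput(session_id="returned-session-id")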