# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import torch

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None
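
# A hedged illustration (not part of the original module): the session_params
# dict accepted by GenerateReqInput below is expected to mirror these fields,
# e.g. SessionParams(**{"id": "sess-1", "offset": 0, "replace": False}).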


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a URL, or a base64-encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling parameters. See SamplingParams for the available options.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If returning logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is -1, which means only the logprobs of output tokens are returned.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens into text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    def normalize_batch_and_arguments(self):
        if (
            sum(x is not None for x in (self.text, self.input_ids, self.input_embeds))
            != 1
        ):
            raise ValueError(
                "Exactly one of text, input_ids, or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
        )
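
# A minimal usage sketch (illustrative, not part of the original module):
# after normalize_batch_and_arguments(), a single prompt with parallel
# sampling (n > 1) is promoted to a batch of size 1, and per-request defaults
# (rid, return_logprob, ...) are expanded to length batch_size * n.
#
#   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
#   req.normalize_batch_and_arguments()
#   assert req.is_single is False and req.batch_size == 1
#   assert req.parallel_sample_num == 2 and len(req.rid) == 2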


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If returning logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
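
# A minimal usage sketch (illustrative, not part of the original module): a
# batch of prompts gets per-request ids and dummy sampling params that force
# max_new_tokens to 0, since embedding requests generate no tokens.
#
#   req = EmbeddingReqInput(text=["a", "b"])
#   req.normalize_batch_and_arguments()
#   assert req.batch_size == 2 and len(req.rid) == 2
#   assert all(p["max_new_tokens"] == 0 for p in req.sampling_params)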


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    # The version ids used to sync the decode status with the detokenizer_manager
    vids: List[int]
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromDistributedReqInput:
    name: str
    dtype: str
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    serialized_named_tensors: bytes  # a serialized Dict[str, torch.Tensor]


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str
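
# A hedged sketch (illustrative only; the server may use a different
# serializer): one plausible way to produce the serialized_named_tensors
# payload above is torch.save on a name -> tensor dict.
#
#   import io
#   buf = io.BytesIO()
#   torch.save({"layer.weight": torch.zeros(4, 4)}, buf)
#   req = UpdateWeightsFromTensorReqInput(serialized_named_tensors=buf.getvalue())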


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str
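
# A hedged sketch (illustrative only, with assumed values): the fields of
# InitWeightsUpdateGroupReqInput mirror what a receiving process would pass to
# torch.distributed.init_process_group to join the weight-update group.
#
#   import torch.distributed as dist
#   req = InitWeightsUpdateGroupReqInput(
#       master_address="127.0.0.1", master_port=29500, rank_offset=1, world_size=2
#   )
#   dist.init_process_group(
#       backend=req.backend,
#       init_method=f"tcp://{req.master_address}:{req.master_port}",
#       rank=req.rank_offset,  # assumption: this process's rank equals the offset
#       world_size=req.world_size,
#       group_name=req.group_name,
#   )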


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class AbortReq:
    # The request id
    rid: str


class ProfileReq(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool