# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import torch

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams


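# Parameters for continual prompting within a session; all fields are optional.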
@dataclass
class SessionParams:
    id: Optional[str] = None
    rid: Optional[str] = None
    offset: Optional[int] = None
    replace: Optional[bool] = None


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a URL, or a base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If returning logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video].
    modalities: Optional[List[str]] = None

    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None

    def normalize_batch_and_arguments(self):
        num_inputs = sum(
            x is not None for x in (self.text, self.input_ids, self.input_embeds)
        )
        if num_inputs != 1:
            raise ValueError(
                "Exactly one of text, input_ids, or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
            self.input_embeds = None
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
            self.input_embeds = None
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in the sampling params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num
            elif isinstance(self.image_data, list):
                pass

            if self.sampling_params is None:
                # Use independent dicts so later per-request mutation is safe.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
        )
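
# Example (illustrative only, not executed): normalization expands a single
# prompt with parallel sampling (n > 1) into a batch and fills in defaults.
#
#   req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
#   req.normalize_batch_and_arguments()
#   assert req.is_single is False and req.batch_size == 1
#   assert req.parallel_sample_num == 2
#   assert req.text == ["Hello"] and len(req.rid) == 2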


@dataclass
class TokenizedGenerateReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If returning logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If returning logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session info for continual prompting
    session_params: Optional[SessionParams] = None


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Exactly one of text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # Use independent dicts so the per-request mutation below is safe.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
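
# Example (illustrative only, not executed): embedding requests reuse the
# request pipeline with generation disabled, so normalization forces
# max_new_tokens to 0 in the dummy sampling params.
#
#   req = EmbeddingReqInput(text=["a", "b"])
#   req.normalize_batch_and_arguments()
#   assert req.batch_size == 2
#   assert all(p["max_new_tokens"] == 0 for p in req.sampling_params)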


@dataclass
class TokenizedEmbeddingReqInput:
    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams


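# Token-level batch output sent to the DetokenizerManager, which detokenizes
# it into a BatchStrOut for the TokenizerManager.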
@dataclass
class BatchTokenIDOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    # The version id to sync decode status with in detokenizer_manager
    vids: List[int]
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--return-token-ids` is set
    origin_input_ids: Optional[List[int]]
    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]


@dataclass
class BatchStrOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]

    # The token ids
    origin_input_ids: Optional[List[int]]
    output_ids: Optional[List[int]]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]


@dataclass
class BatchEmbeddingOut:
    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]


@dataclass
class FlushCacheReq:
    pass


@dataclass
class UpdateWeightFromDiskReqInput:
    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightsFromDistributedReqInput:
    # The name of the parameter to update
    name: str
    # The dtype of the incoming tensor
    dtype: str
    # The shape of the incoming tensor
    shape: List[int]


@dataclass
class UpdateWeightsFromDistributedReqOutput:
    # Whether the update succeeded
    success: bool
    # The status message
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    # The name of the parameter to update
    name: str
    # The tensor carrying the new weights
    tensor: torch.Tensor


@dataclass
class UpdateWeightsFromTensorReqOutput:
    # Whether the update succeeded
    success: bool
    # The status message
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"


@dataclass
class InitWeightsUpdateGroupReqOutput:
    success: bool
    message: str
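
# Illustrative flow (a sketch inferred from the request/response pairs above):
# a client first sends InitWeightsUpdateGroupReqInput to create the weight
# update process group, then one UpdateWeightsFromDistributedReqInput per
# tensor; the server replies to each request with the matching ReqOutput.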


@dataclass
class GetWeightsByNameReqInput:
    # The name of the parameter to fetch
    name: str
    # Truncate the returned values to at most this many elements
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    # The (possibly truncated) parameter values
    parameter: list


@dataclass
class AbortReq:
    # The request id
    rid: str


class ProfileReq(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool