"""
The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams


@dataclass
class GenerateReqInput:
    """A raw generation request as received by the TokenizerManager.

    Every per-request field may be given either as a single value (one
    request) or as a list (a batch of requests).  ``post_init`` validates
    the request, decides whether it is single or batched, and broadcasts
    scalar fields into per-request lists for the batch case.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params (a dict shared by all requests or one dict per request).
    sampling_params: Union[List[Dict], Dict] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # The number of top logprobs to return.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False

    def post_init(self):
        """Validate the request and normalize its fields.

        Sets ``self.is_single``; for batch requests also sets
        ``self.batch_size`` and broadcasts ``image_data``,
        ``sampling_params``, ``rid``, ``return_logprob``,
        ``logprob_start_len`` and ``top_logprobs_num`` to lists with one
        entry per (expanded) request.

        Raises:
            ValueError: if neither or both of ``text``/``input_ids`` are
                provided, or if a batch request carries a non-list ``rid``.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Parallel sampling (n > 1) always takes the batch path, even for a
        # single prompt, because it fans out into multiple requests.
        if (
            isinstance(self.sampling_params, dict)
            and self.sampling_params.get("n", 1) != 1
        ):
            is_single = False
        else:
            if self.text is not None:
                is_single = isinstance(self.text, str)
            else:
                is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = 0
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            # BUG FIX: sampling_params may be None or a per-request list
            # here; only a dict can carry a meaningful "n".  The previous
            # unconditional .get() crashed with AttributeError for any
            # batch request without an explicit sampling_params dict.
            if isinstance(self.sampling_params, dict):
                parallel_sample_num = self.sampling_params.get("n", 1)
            else:
                parallel_sample_num = 1

            if parallel_sample_num != 1:
                # Parallel sampling: +1 represents the original prefill stage.
                num = parallel_sample_num + 1
                if isinstance(self.text, list):
                    # Support batch operation.
                    self.batch_size = len(self.text)
                    num = num * len(self.text)
                else:
                    self.batch_size = 1
            else:
                # Support select operation.
                num = len(self.text) if self.text is not None else len(self.input_ids)
                self.batch_size = num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            elif not isinstance(self.rid, list):
                raise ValueError("The rid should be a list.")

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num

            if self.logprob_start_len is None:
                self.logprob_start_len = [0] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num


@dataclass
class TokenizedGenerateReqInput:
    """A single tokenized generation request, sent onward after the
    TokenizerManager has converted a raw ``GenerateReqInput``."""

    # The request id.
    rid: str
    # The original prompt text.
    input_text: str
    # The tokenized prompt.
    input_ids: List[int]
    # Image features for the request.  NOTE(review): presumably the
    # preprocessed pixel tensor flattened to floats — confirm against
    # the image loading path.
    pixel_values: List[float]
    # Hash of the image data; presumably used as a cache/dedup key — confirm.
    image_hash: int
    # Image dimensions.  NOTE(review): ordering (width/height) not visible
    # here — confirm against the producer.
    image_size: List[int]
    # The sampling parameters for this request.
    sampling_params: SamplingParams
    # Whether to return logprobs.
    return_logprob: bool
    # The start location of the prompt for return_logprob.
    logprob_start_len: int
    # The number of top logprobs to return.
    top_logprobs_num: int
    # Whether to stream output.
    stream: bool


@dataclass
class BatchTokenIDOut:
    """A batch of token-level outputs; each list holds one entry per
    request in the batch (parallel arrays indexed by request)."""

    # The request ids.
    rids: List[str]
    # Per-request integer ids.  NOTE(review): presumably version ids used
    # to invalidate stale incremental decodes — confirm against consumer.
    vids: List[int]
    # Text decoded so far for each request.
    decoded_texts: List[str]
    # Token ids pending detokenization.
    decode_ids: List[int]
    # Per-request read offsets into the decode stream.
    read_offsets: List[int]
    # Per-request flag: skip special tokens when detokenizing.
    skip_special_tokens: List[bool]
    # Per-request flag: insert spaces between special tokens.
    spaces_between_special_tokens: List[bool]
    # Per-request metadata dict (contents defined by the producer).
    meta_info: List[Dict]
    # Per-request finish reason.
    finished_reason: List[BaseFinishReason]


@dataclass
class BatchStrOut:
    """A batch of detokenized string outputs; each list holds one entry
    per request in the batch (parallel arrays indexed by request)."""

    # The request ids.
    rids: List[str]
    # The decoded output string for each request.
    output_strs: List[str]
    # Per-request metadata dict (contents defined by the producer).
    meta_info: List[Dict]
    # Per-request finish reason.
    finished_reason: List[BaseFinishReason]


@dataclass
class FlushCacheReq:
    """Control message requesting a cache flush; carries no payload.
    NOTE(review): presumably flushes the controller's prefix/KV cache —
    confirm against the handler."""

    pass


@dataclass
class AbortReq:
    """Control message asking the controller to abort an in-flight request."""

    # The id of the request to abort.
    rid: str


@dataclass
class DetokenizeReqInput:
    """Request to detokenize a sequence of token ids back into text."""

    # The token ids to detokenize.
    input_ids: List[int]