"""io_struct.py: dataclasses for generation requests and batched outputs."""
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.sampling_params import SamplingParams


@dataclass
class GenerateReqInput:
    """A generation request holding either one prompt or a batch of prompts.

    The type of ``text`` selects the mode: a ``str`` is a single request,
    anything else is treated as a batch of ``len(text)`` requests.

    Defaults are filled in by :meth:`post_init`, which is an ordinary method
    (not the dataclass ``__post_init__`` hook) and therefore must be called
    explicitly after construction.
    """

    # One prompt, or a list of prompts for a batch.
    text: Union[List[str], str]
    # Optional image input(s); in batch mode a single value is broadcast.
    image_data: Optional[Union[List[str], str]] = None
    # Sampling options as plain dict(s); ``None`` means "use empty dict(s)".
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Request id(s); random hex ids are generated when omitted.
    rid: Optional[Union[List[str], str]] = None
    # Log-probability flag(s); ``None`` is normalized to ``False``.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # Log-probability start offset(s); ``None`` is normalized to ``0``.
    logprob_start_len: Optional[Union[List[int], int]] = None
    stream: bool = False

    def post_init(self) -> None:
        """Fill in defaults and, in batch mode, broadcast scalars to lists.

        Single mode (``text`` is a ``str``): ``sampling_params``, ``rid``,
        ``return_logprob`` and ``logprob_start_len`` get scalar defaults when
        they are ``None``; ``image_data`` is left untouched.

        Batch mode: every per-request field becomes a list of ``len(text)``
        items. A non-list value is repeated for each prompt; lists passed by
        the caller are kept as they are (their length is not checked).

        Raises:
            AssertionError: in batch mode, if ``rid`` is given but is not a
                list.
        """
        if isinstance(self.text, str):
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = 0
            return

        num = len(self.text)

        if self.image_data is None:
            self.image_data = [None] * num
        elif not isinstance(self.image_data, list):
            self.image_data = [self.image_data] * num

        # Give every request its own dict. ``[d] * num`` would put the very
        # same object in each slot, so changing one request's parameters
        # would silently change all of them.
        if self.sampling_params is None:
            self.sampling_params = [{} for _ in range(num)]
        elif not isinstance(self.sampling_params, list):
            shared = self.sampling_params
            self.sampling_params = [dict(shared) for _ in range(num)]

        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        else:
            assert isinstance(self.rid, list)

        if self.return_logprob is None:
            self.return_logprob = [False] * num
        elif not isinstance(self.return_logprob, list):
            self.return_logprob = [self.return_logprob] * num

        if self.logprob_start_len is None:
            self.logprob_start_len = [0] * num
        elif not isinstance(self.logprob_start_len, list):
            self.logprob_start_len = [self.logprob_start_len] * num


@dataclass
class TokenizedGenerateReqInput:
    """A single generation request in tokenized form.

    Plain data holder: every field is required and the class adds no
    behavior of its own.
    """

    # Request id.
    rid: str
    # Prompt text.
    input_text: str
    # Token ids; presumably the tokenization of input_text (confirm
    # against the producer).
    input_ids: List[int]
    # Image-related inputs; NOTE(review): exact layout/meaning is not visible
    # in this file, confirm against the producer.
    pixel_values: List[float]
    image_hash: int
    image_size: List[int]
    # Sampling options as a SamplingParams object rather than a raw dict.
    sampling_params: SamplingParams
    return_logprob: bool
    logprob_start_len: int
    stream: bool


@dataclass
class BatchTokenIDOut:
    """Token-level output for a batch of requests.

    Plain data holder made of lists; presumably all lists are index-aligned
    with ``rids`` (one entry per request). Confirm against the producer.
    """

    # Request ids.
    rids: List[str]
    # One list of token ids per request.
    output_tokens: List[List[int]]
    # One string per request; NOTE(review): exact contents are defined by the
    # producer and are not visible in this file.
    output_and_fast_forward_strs: List[str]
    # Per request: a stop string, or None.
    hit_stop_str: List[Optional[str]]
    # Per-request flag.
    skip_special_tokens: List[bool]
    # Per-request metadata dict.
    meta_info: List[Dict]
    # Per-request completion flag.
    finished: List[bool]


@dataclass
class BatchStrOut:
    """String-level output for a batch of requests.

    Plain data holder made of lists; presumably all lists are index-aligned
    with ``rids`` (one entry per request). Confirm against the producer.
    """

    # Request ids.
    rids: List[str]
    # One output string per request.
    output_str: List[str]
    # Per-request metadata dict.
    meta_info: List[Dict]
    # Per-request completion flag.
    finished: List[bool]