import copy
import enum
from typing import Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams


class SequenceStatus(enum.Enum):
    """Status of a single sequence.

    The ``FINISHED_*`` members are the states that :meth:`is_finished`
    treats as finished; :meth:`get_finished_reason` maps each of them to a
    short reason string.
    """
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the ``FINISHED_*`` states."""
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Return the finish-reason string for ``status``.

        Returns:
            ``"stop"``, ``"length"`` or ``"abort"`` for the finished states,
            and ``None`` for any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences report the same reason as length-capped ones.
            return "length"
        return None


class SequenceData:
    """Token ids of a sequence plus the running log-probability total.

    Attributes:
        prompt_token_ids: The list passed to the constructor (stored as-is,
            not copied).
        output_token_ids: Token ids appended via :meth:`append_token_id`.
        cumulative_logprob: Sum of the ``logprob`` values passed to
            :meth:`append_token_id`; starts at 0.0.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record one output token and add its log-probability to the total."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the number of prompt tokens plus output tokens."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list: prompt token ids followed by output token ids."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Return the last output token id, or the last prompt token id if
        no output token has been appended yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """A single sequence: prompt, generated output and logical token blocks.

    Attributes:
        seq_id: Identifier passed to the constructor.
        prompt: The prompt string.
        block_size: Size handed to every ``LogicalTokenBlock`` created here.
        data: ``SequenceData`` holding prompt/output token ids and the
            cumulative log-probability.
        output_logprobs: One ``{token_id: logprob}`` dict per appended token.
        output_tokens: List of output token strings; starts empty and is not
            written by any method of this class.
        output_text: Output text; starts empty and is not written by any
            method of this class.
        logical_token_blocks: Blocks holding the prompt and output token ids.
        status: Current ``SequenceStatus``; starts as ``WAITING``.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: List[Dict[int, float]] = []
        self.output_tokens: List[str] = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

    def _append_logical_block(self) -> None:
        """Append a new block numbered with the current block count."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Distribute ``token_ids`` over the blocks, adding blocks as needed.

        The last block is topped up first; a new block is appended whenever
        there is no block yet or the last one reports itself full.
        """
        while token_ids:
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[:num_empty_slots])
            # Slicing rebinds the local name; the caller's list is untouched.
            token_ids = token_ids[num_empty_slots:]

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
    ) -> None:
        """Append one generated token.

        Args:
            token_id: The token id to append; must be a key of ``logprobs``.
            logprobs: Mapping of token id to log-probability for this step.
                The whole dict is stored in ``output_logprobs`` and the entry
                for ``token_id`` is added to the cumulative log-probability.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])

    def get_len(self) -> int:
        """Return the total number of tokens (delegates to ``data``)."""
        return self.data.get_len()

    def get_output_len(self) -> int:
        """Return the number of output tokens (delegates to ``data``)."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """Return all token ids (delegates to ``data``)."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Return the most recent token id (delegates to ``data``)."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """Return the output token id list held by ``data`` (not a copy)."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """Return the cumulative log-probability held by ``data``."""
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        """Return True if ``status`` is one of the finished states."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, child_seq: 'Sequence') -> None:
        """Overwrite ``child_seq``'s token state with deep copies of this one.

        Copies ``logical_token_blocks``, ``output_logprobs`` and ``data``.
        """
        # NOTE(review): output_tokens and output_text are NOT copied, so the
        # child keeps its own values for those while its token ids change.
        # Confirm with the callers whether that is intended.
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.data = copy.deepcopy(self.data)

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')

class SequenceGroup:
    """A list of sequences sharing one request id and sampling parameters.

    Attributes:
        request_id: Identifier of the request.
        seqs: The sequences in this group.
        sampling_params: Sampling parameters stored for the group.
        arrival_time: Arrival time passed to the constructor.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally filtered by ``status``.

        With ``status=None`` the internal list itself is returned (not a
        copy); otherwise a new list of the matching sequences is built.
        """
        if status is None:
            return self.seqs
        return [seq for seq in self.seqs if seq.status == status]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Return how many sequences :meth:`get_seqs` yields for ``status``."""
        return len(self.get_seqs(status))

    def find(self, seq_id: int) -> Sequence:
        """Return the first sequence whose ``seq_id`` matches.

        Raises:
            ValueError: If no sequence in the group has that id.
        """
        for seq in self.seqs:
            if seq.seq_id == seq_id:
                return seq
        raise ValueError(f'Sequence {seq_id} not found.')

    def is_finished(self) -> bool:
        """Return True if every sequence is finished (True for an empty group)."""
        return all(seq.is_finished() for seq in self.seqs)

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadata:
    """Plain container for per-request metadata.

    Attributes:
        request_id: Identifier of the request.
        is_prompt: Boolean flag stored as given.
            NOTE(review): presumably marks a prompt step as opposed to a
            generation step -- confirm against the callers.
        seq_data: Mapping of sequence id to its ``SequenceData``.
        sampling_params: Sampling parameters stored as given.
        block_tables: Mapping of sequence id to its list of physical block
            numbers.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],      # Seq id -> sequence data.
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],     # Seq id -> list of physical block numbers.
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables


class SequenceOutputs:
    """Output record for one sequence.

    Attributes:
        seq_id: Id of the sequence this output belongs to.
        parent_seq_id: Id of the parent sequence.
        output_token: The output token id.
        logprobs: Token id -> logP(x_i+1 | x_0, ..., x_i).
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],         # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        # The closing parenthesis belongs after the last field; it used to
        # sit after output_token, leaving logprobs outside the parentheses.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')

    def __eq__(self, other: object) -> bool:
        """Field-wise equality; defers to the other operand for foreign types.

        Because ``__eq__`` is defined without ``__hash__``, Python makes
        instances of this class unhashable.
        """
        if not isinstance(other, SequenceOutputs):
            return NotImplemented
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)