sequence.py 13.6 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
Zhuohan Li's avatar
Zhuohan Li committed
4
from typing import Dict, List, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
11
PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]

Woosuk Kwon's avatar
Woosuk Kwon committed
12
13

class SequenceStatus(enum.Enum):
    """Lifecycle state of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True iff `status` is one of the terminal states."""
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a terminal status to its OpenAI-style finish reason.

        Returns None for statuses that are not terminal.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
48

49

50
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record a newly sampled token and accumulate its log probability."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Total number of tokens (prompt + output)."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Number of generated tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """All token IDs: prompt followed by output."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Most recent token ID; falls back to the prompt's final token."""
        if self.output_token_ids:
            return self.output_token_ids[-1]
        return self.prompt_token_ids[-1]

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
97
98


Woosuk Kwon's avatar
Woosuk Kwon committed
99
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        # Lay out the prompt token IDs into logical blocks up front.
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    def _append_logical_block(self) -> None:
        """Open a fresh (empty) logical block at the end of the block list."""
        new_block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(new_block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write `token_ids` into the logical blocks, growing them as needed."""
        position = 0
        num_tokens = len(token_ids)
        while position < num_tokens:
            # Make sure the tail block exists and has at least one free slot.
            blocks = self.logical_token_blocks
            if not blocks or blocks[-1].is_full():
                self._append_logical_block()
            tail = self.logical_token_blocks[-1]
            free_slots = tail.get_num_empty_slots()
            tail.append_tokens(token_ids[position:position + free_slots])
            position += free_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
    ) -> None:
        """Record one generated token together with its sampling logprobs."""
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])

    def get_len(self) -> int:
        """Total number of tokens (prompt + output)."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 0.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            ends_with_eos = (eos_token_id is not None
                             and self.get_last_token_id() == eos_token_id)
            if ends_with_eos:
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        """Whether this sequence has reached a terminal status."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence under a new ID (used by beam search)."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
Woosuk Kwon's avatar
Woosuk Kwon committed
221

Woosuk Kwon's avatar
Woosuk Kwon committed
222
223

class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        # Logprobs over the prompt tokens; filled in later if requested.
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
    def prompt(self) -> str:
        # All sequences in the group share the same prompt, so an
        # arbitrary member can supply it.
        some_seq = next(iter(self.seqs_dict.values()))
        return some_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group share the same prompt, so an
        # arbitrary member can supply it.
        some_seq = next(iter(self.seqs_dict.values()))
        return some_seq.data.prompt_token_ids

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        params = self.sampling_params
        if params.use_beam_search:
            # Beam search always keeps `best_of` candidate beams alive.
            return params.best_of
        if params.best_of > self.num_seqs():
            # Prompt stage: the group is not yet filled up and only has one
            # sequence running; `best_of` sequences will run once generation
            # starts.
            return params.best_of
        # Generation stage: only the not-yet-finished sequences keep running.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the group's sequences, optionally filtered by status."""
        all_seqs = self.seqs_dict.values()
        if status is None:
            return list(all_seqs)
        return [seq for seq in all_seqs if seq.status == status]

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Sequences that have not yet reached a terminal status."""
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        """Sequences that have reached a terminal status."""
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Look up a sequence by ID; raise ValueError if it is absent."""
        seq = self.seqs_dict.get(seq_id)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        return seq

    def add(self, seq: Sequence) -> None:
        """Insert a new sequence; raise ValueError on a duplicate ID."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Drop a sequence by ID; raise ValueError if it is absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        """Whether every sequence in the group has finished."""
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
325
326


327
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        # Passive container: stores the constructor arguments verbatim.
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables


class SequenceOutputs:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutputs):
            # Per the Python data model, return NotImplemented (instead of
            # raising) so comparison against a foreign type falls back to the
            # default behavior: `==` evaluates to False, `!=` to True, and
            # membership tests against mixed containers keep working.
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
387
388


389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
class SequenceGroupOutputs:
    """The model outputs associated with a sequence group.

    Args:
        samples: Candidate next-token outputs, one `SequenceOutputs` each.
        prompt_logprobs: Logprobs over the prompt tokens, if computed.
    """

    def __init__(
        self,
        samples: List[SequenceOutputs],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutputs(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")


405
406
# For each sequence group, we generate a list of SequenceOutputs objects,
# each of which contains one possible candidate for the next token.
407
SamplerOutput = List[SequenceGroupOutputs]