sequence.py 14.2 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
Zhuohan Li's avatar
Zhuohan Li committed
4
from typing import Dict, List, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.block import LogicalTokenBlock
7
from vllm.prefix import Prefix
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
11
12
# Per-position mappings from token id to log probability. Prompt positions
# may carry None (no logprobs recorded for that position).
PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]

Woosuk Kwon's avatar
Woosuk Kwon committed
13
14

class SequenceStatus(enum.Enum):
    """Status of a sequence in the scheduler lifecycle."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True iff *status* is one of the terminal FINISHED_* states."""
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a terminal status to its OpenAI-style finish reason.

        Returns None for non-terminal statuses. FINISHED_IGNORED also maps
        to "length": ignored sequences are the ones whose prompt lengths
        exceed the model's length cap, so the stop reason matches the
        OpenAI API convention.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
49

50

51
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record one generated token and fold its logprob into the total."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Total number of tokens (prompt + output)."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_prompt_len(self) -> int:
        """Number of prompt tokens only."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Number of generated tokens only."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """All token ids, prompt first, then output."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """The most recent token id: last output token if any, else the
        last prompt token."""
        if self.output_token_ids:
            return self.output_token_ids[-1]
        return self.prompt_token_ids[-1]

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
98
99


Woosuk Kwon's avatar
Woosuk Kwon committed
100
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        # Lay the prompt tokens out into logical token blocks up front.
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    def _append_logical_block(self) -> None:
        """Grow the block list by one fresh (empty) logical block."""
        next_block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(next_block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write *token_ids* into the logical blocks, allocating a new
        block whenever the last one has no empty slots left."""
        pos = 0
        while pos < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            tail_block = self.logical_token_blocks[-1]
            if tail_block.is_full():
                self._append_logical_block()
                tail_block = self.logical_token_blocks[-1]

            free_slots = tail_block.get_num_empty_slots()
            tail_block.append_tokens(token_ids[pos:pos + free_slots])
            pos += free_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
    ) -> None:
        """Record one generated token together with its logprob map."""
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])

    def get_len(self) -> int:
        """Total token count (prompt + output), delegated to the data."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Prompt token count, delegated to the data."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Output token count, delegated to the data."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """All token ids (prompt then output), delegated to the data."""
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        """Most recent token id, delegated to the data."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """The generated token ids only."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """Cumulative log probability of the generated tokens."""
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 0.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        """True iff this sequence has reached a terminal status."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence (blocks, data, status) under a new id."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
Woosuk Kwon's avatar
Woosuk Kwon committed
222

Woosuk Kwon's avatar
Woosuk Kwon committed
223
224

class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        prefix: The prefix of the prompt of the sequence group.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        prefix: Optional[Prefix] = None,
    ) -> None:
        self.request_id = request_id
        # Index the sequences by their ids for O(1) lookup.
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        self.prefix: Optional[Prefix] = prefix
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        best_of = self.sampling_params.best_of
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return best_of
        if best_of > self.num_seqs():
            # At prompt stage, the sequence group is not yet filled up
            # and only have one sequence running. However, in the
            # generation stage, we will have `best_of` sequences running.
            return best_of
        # At sampling stages, return the number of actual sequences
        # that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally filtered by *status*."""
        seqs = self.seqs_dict.values()
        if status is None:
            return list(seqs)
        return [seq for seq in seqs if seq.status == status]

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Sequences that have not yet reached a terminal status."""
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        """Sequences that have reached a terminal status."""
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count sequences, optionally restricted to one status."""
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Look up a sequence by id; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Insert a new sequence; raise ValueError on duplicate id."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Drop a sequence by id; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        """True iff every sequence in the group is finished."""
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
329
330


331
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        prefix: The prefix of the prompt of the sequence group.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        prefix: Optional[Prefix] = None,
    ) -> None:
        # Plain data holder: every constructor argument is stored verbatim
        # with no validation or transformation.
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.prefix = prefix
359
360


Zhuohan Li's avatar
Zhuohan Li committed
361
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        # Per the Python data model, return NotImplemented (don't raise)
        # for foreign types so that comparisons like `x == None` or list
        # membership tests fall back to inequality instead of crashing.
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
393
394


Zhuohan Li's avatar
Zhuohan Li committed
395
396
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: One `SequenceOutput` per next-token candidate sampled for
            the group.
        prompt_logprobs: Logprobs for the prompt tokens, or None when they
            were not requested.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        # Per the Python data model, return NotImplemented (don't raise)
        # for foreign types so equality tests against other types degrade
        # to inequality instead of crashing.
        if not isinstance(other, SequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

416

Zhuohan Li's avatar
Zhuohan Li committed
417
# For each sequence group, we generate a list of SequenceOutput objects,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutput]