sequence.py 17.7 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from dataclasses import dataclass
Zhuohan Li's avatar
Zhuohan Li committed
5
from typing import Any, Dict, List, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams
9
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
16
17
18
19
20

@dataclass
class Logprob:
    """Log-probability of a single token, used for OpenAI-compatible logprobs.

    Optionally carries the token's decoded text alongside the number.
    """
    # Log-probability assigned to the token.
    logprob: float
    # Text of the token, if available; defaults to None.
    decoded_token: Optional[str] = None


# One entry per prompt position, mapping token id -> Logprob; an entry may
# be None instead of a dict.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# One entry per generated token, mapping token id -> Logprob.
SampleLogprobs = List[Dict[int, Logprob]]
21

Woosuk Kwon's avatar
Woosuk Kwon committed
22
23

class SequenceStatus(enum.Enum):
    """Status of a sequence.

    The FINISHED_* members are the terminal states.
    """
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is one of the FINISHED_* members."""
        terminal_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in terminal_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a terminal status into an OpenAI-style finish reason.

        Returns None for any status that has no finish reason.
        """
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so -- as in the OpenAI API -- they also
            # report "length".
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for candidate, reason in reason_table:
            if status == candidate:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    The Optional fields are None until the corresponding event is recorded.

    Args:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the last token was generated; before
            any token exists this is expected to equal the arrival time.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


79
class SequenceData:
    """Token-level data of one sequence: prompt tokens, generated tokens,
    and the running sum of the generated tokens' logprobs.

    Args:
        prompt_token_ids: The token IDs of the prompt. The list is kept by
            reference, not copied.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs generated so far, in order.
        cumulative_logprob: Sum of every logprob passed to `append_token_id`.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record one generated token and add its logprob to the total."""
        outputs = self.output_token_ids
        outputs.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Total number of tokens: prompt plus output."""
        return self.get_prompt_len() + self.get_output_len()

    def get_prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Number of generated tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt tokens followed by outputs."""
        return [*self.prompt_token_ids, *self.output_token_ids]

    def get_last_token_id(self) -> int:
        """Return the newest token: the last output token if any exists,
        otherwise the last prompt token."""
        source = self.output_token_ids or self.prompt_token_ids
        return source[-1]

    def __repr__(self) -> str:
        body = (f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}")
        return f"SequenceData({body})"
125
126


Woosuk Kwon's avatar
Woosuk Kwon committed
127
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if known. Stored as
            `self.eos_token_id`.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        """Integer ID of the attached LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the tokens from the start of the sequence through the end of
        logical block `logical_idx`.

        The hashed span is the whole prefix, not just the block, so the
        result depends on every token before the block as well.
        """
        # Number of tokens covered by blocks 0..logical_idx (if all full).
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(tuple(self.data.get_token_ids()[0:num_tokens]))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens in blocks 0..logical_idx, assuming each block
        holds `block_size` tokens."""
        return logical_idx * self.block_size + self.block_size

    def _append_logical_block(self) -> None:
        """Append an empty logical block numbered after the existing ones."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Distribute `token_ids` over the logical blocks, filling the last
        block first and allocating new blocks as each one fills up."""
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            # May step past the end on the final chunk; the loop condition
            # then terminates the loop.
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token to the blocks, the logprob history and
        the sequence data.

        `logprobs` must contain an entry for `token_id`; its logprob is added
        to the cumulative logprob.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

        Args:
            length_penalty: Exponent applied to the sequence length.
            seq_len: Length to normalize by. When None, the full sequence
                length is used (minus one if the last token is
                `eos_token_id`).
            eos_token_id: Only consulted when `seq_len` is None.

        Returns:
            cumulative_logprob / seq_len ** length_penalty
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new ID."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
Woosuk Kwon's avatar
Woosuk Kwon committed
268

Woosuk Kwon's avatar
Woosuk Kwon committed
269

Nick Hill's avatar
Nick Hill committed
270
271
272
273
274
275
276
277
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling. Annotated as Optional[Any]
    # because this module does not import torch, and a bare `Optional`
    # (no type argument) is rejected by type checkers.
    generator: Optional[Any] = None


Woosuk Kwon's avatar
Woosuk Kwon committed
278
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # last_token_time starts at arrival_time so the first call to
        # get_last_latency() measures the interval since arrival.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        # Raises StopIteration if the group holds no sequences.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        # Raises StopIteration if the group holds no sequences.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        """Integer ID of the attached LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Gets last token latency for Request level timings.

        Returns `now - metrics.last_token_time` and then sets
        `metrics.last_token_time` to `now`, so each call measures the
        interval since the previous call (or since arrival for the first).
        """
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings.

        Only the first call has an effect.
        """
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request level
        timings.

        Only the first call has an effect; time in queue is computed as
        `time - metrics.arrival_time`.
        """
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status equals `status`."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`; raise ValueError if absent."""
        try:
            return self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def add(self, seq: Sequence) -> None:
        """Add `seq`; raise ValueError if its seq_id is already present."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id`; raise ValueError if absent."""
        try:
            del self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def is_finished(self) -> bool:
        """True when every sequence is finished (vacuously True if empty)."""
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
410
411


412
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        lora_request: LoRA request.
        computed_block_nums: Optional list of block numbers. Judging by the
            name, these are blocks whose contents are already computed.
            NOTE(review): confirm against the code that builds this object.
        state: Internal state tied to this sequence group. A fresh
            `SequenceGroupState` is created when None is passed.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.state = SequenceGroupState() if state is None else state

    @property
    def lora_int_id(self) -> int:
        """Integer ID of the attached LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

450

Zhuohan Li's avatar
Zhuohan Li committed
451
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, "Logprob"],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    # Defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other: object) -> bool:
        """Equal when parent id, output token and logprobs all match.

        Returns NotImplemented (rather than raising) for other types, so
        comparisons such as `out == None` or `out in [None]` evaluate to
        False instead of crashing.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
484
485


Zhuohan Li's avatar
Zhuohan Li committed
486
487
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: One SequenceOutput per candidate produced for the group.
        prompt_logprobs: Prompt logprobs for the group, or None.
    """

    def __init__(
        self,
        samples: List["SequenceOutput"],
        prompt_logprobs: Optional["PromptLogprobs"],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    # Defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other: object) -> bool:
        """Equal when both samples and prompt_logprobs match.

        Returns NotImplemented (rather than raising) for other types so that
        mixed-type comparisons evaluate to False instead of crashing.
        """
        if not isinstance(other, SequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

507

Zhuohan Li's avatar
Zhuohan Li committed
508
# Sampler result: one SequenceGroupOutput per sequence group. Each of those
# holds a list of SequenceOutput objects, one per candidate next token.
SamplerOutput = List[SequenceGroupOutput]