sequence.py 30.7 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.block import LogicalTokenBlock
9
from vllm.lora.request import LoRARequest
10
from vllm.pooling_params import PoolingParams
11
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
14
if TYPE_CHECKING:
    import torch
15

16
17
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

18
19
20

@dataclass
class Logprob:
    """Information for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token (a string, not an
            index), if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


33
34
# Prompt logprobs: a list of {token_id -> Logprob} dicts, one entry per prompt
# token. An entry may be None (e.g. when prompt logprobs were not required for
# that position).
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: a list of {token_id -> Logprob} dicts, one entry per
# generated token of a sequence.
SampleLogprobs = List[Dict[int, Logprob]]
38

Woosuk Kwon's avatar
Woosuk Kwon committed
39
40

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* states."""
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its OpenAI-style finish reason.

        Returns:
            "stop", "length" or "abort" for finished statuses, and None for
            any status that is not finished.
        """
        finish_reasons = {
            SequenceStatus.FINISHED_STOPPED: "stop",
            SequenceStatus.FINISHED_LENGTH_CAPPED: "length",
            SequenceStatus.FINISHED_ABORTED: "abort",
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            SequenceStatus.FINISHED_IGNORED: "length",
        }
        return finish_reasons.get(status)
Woosuk Kwon's avatar
Woosuk Kwon committed
75

76

77
78
79
80
81
class SequenceStage(enum.Enum):
    """Processing stage of a sequence's tokens.

    A sequence starts in ``PREFILL``; ``SequenceData`` moves it to ``DECODE``
    once every token has been computed.
    """
    # Some tokens are still waiting to be run through the model.
    PREFILL = enum.auto()
    # All tokens computed; new tokens are generated one at a time.
    DECODE = enum.auto()


82
83
84
85
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time at which the latest token latency was
            recorded (SequenceGroup initializes it to the arrival time).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


101
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        # Tuple snapshot of the prompt so that the prefixes returned by
        # get_prefix_token_ids() are hashable. NOTE(review): this snapshot is
        # not refreshed if prompt_token_ids is later mutated in place.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list of the prompt token IDs followed by the output
        token IDs."""
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get the first ``num_tokens`` tokens as a hashable value.

        Returns:
            A pair ``(prompt_part, output_part)``. If ``num_tokens`` exceeds
            the prompt length, ``prompt_part`` is the whole prompt and
            ``output_part`` holds the first ``num_tokens - prompt_len`` output
            tokens. Otherwise ``prompt_part`` is the prompt prefix of length
            ``num_tokens`` and ``output_part`` is None.
        """
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far.

        Switches the stage to DECODE once no uncomputed tokens remain.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there is
        no output yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """The current processing stage (PREFILL or DECODE)."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
206
207


Woosuk Kwon's avatar
Woosuk Kwon committed
208
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data: SequenceData = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        """LoRA integer ID of this sequence, or 0 when no LoRA is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int) -> str:
        """Return the output text, holding back the last ``buffer_length``
        characters while the sequence is unfinished.

        A falsy ``buffer_length`` (e.g. 0) or a finished sequence returns the
        full text.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash all tokens up to and including logical block ``logical_idx``,
        combined with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens covered by blocks ``0..logical_idx`` inclusive."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self) -> None:
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        """Append a new, empty logical block numbered after the existing ones."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the logical blocks, allocating new blocks
        whenever the last one is full."""
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token; ``logprobs`` must contain ``token_id``."""
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        # Go through the SequenceData accessor, like the sibling getters.
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying ``new_seq_id``."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
Woosuk Kwon's avatar
Woosuk Kwon committed
380

Woosuk Kwon's avatar
Woosuk Kwon committed
381

Nick Hill's avatar
Nick Hill committed
382
383
384
385
386
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling (None when not seeded). The
    # annotation is a string because torch is only imported for type checking.
    generator: Optional["torch.Generator"] = None
Nick Hill's avatar
Nick Hill committed
388
389


390
391
class MultiModalData:
    """Multi modal request.

    Args:
        type: The data type.
        data: The actual data. The required shape and semantic meaning of it
            depends on the vision language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        """Supported kinds of multi-modal data."""
        IMAGE = enum.auto()

    # NOTE: the parameter name `type` shadows the builtin but is part of the
    # public keyword interface, so it is kept.
    def __init__(self, type: Type, data: "torch.Tensor") -> None:
        self.type = type
        self.data = data


Woosuk Kwon's avatar
Woosuk Kwon committed
409
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
        embeddings: The embedding vector of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data
        self.embeddings = embeddings
        self.pooling_params = pooling_params

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        """LoRA integer ID of this group, or 0 when no LoRA is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the latency since the last recorded token time.

        Computes ``now - metrics.last_token_time`` and then records ``now`` as
        the new ``metrics.last_token_time``.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Record ``time`` as the first token time, if it is not set yet and
        the first sequence has exactly one output token."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status is ``status``."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far (unfinished sequences
        only)."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Return the total number of uncomputed tokens across unfinished
        sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.seqs_dict.values()
                   if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with ``seq_id``.

        Raises:
            ValueError: If no such sequence exists in the group.
        """
        try:
            return self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def add(self, seq: Sequence) -> None:
        """Add ``seq`` to the group.

        Raises:
            ValueError: If a sequence with the same ID already exists.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with ``seq_id``.

        Raises:
            ValueError: If no such sequence exists in the group.
        """
        try:
            del self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
588
589


590
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes query tokens for part of the prefill.
        pooling_params: The pooling parameters, if any.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            if is_prompt:
                # No explicit chunk size: process the first sequence's full
                # length in one step.
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """LoRA integer ID of this group, or 0 when no LoRA is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> Optional[int]:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size

661

Zhuohan Li's avatar
Zhuohan Li committed
662
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when parent ID, output token and logprobs all match.

        Comparing with a non-``SequenceOutput`` returns ``NotImplemented`` so
        Python falls back to its default handling (``out == None`` is False)
        instead of raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
695
696


697
698
699
700
701
702
703
704
705
706
707
708
709
710
class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation; concrete outputs must override."""
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare two outputs for equality; concrete outputs must override."""
        ...


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The ``SequenceOutput`` samples produced for this group.
        prompt_logprobs: Logprobs for each prompt query token, or ``None``.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        Returns ``NotImplemented`` for operands of another type so Python
        falls back to the reflected/identity comparison instead of raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

731

732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values as a flat list of floats.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown; embeddings can be long.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings element-wise.

        Returns ``NotImplemented`` for operands of another type so Python
        falls back to the reflected/identity comparison instead of raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


751
752
753
754
755
@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors. Equality compares only
    ``outputs``; the tensor fields and metrics are ignored.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self) -> int:
        return len(self.outputs)

    def __eq__(self, other: object) -> bool:
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    @staticmethod
    def _shape_repr(tensor: Optional["torch.Tensor"]) -> object:
        """Return ``tensor.shape``, or the string "None" when unset."""
        return "None" if tensor is None else tensor.shape

    def __repr__(self) -> str:
        """Show the shape of each tensor instead of its values to reduce
        noise. Every tensor field (including ``logprobs``) is listed."""
        probs_repr = self._shape_repr(self.sampled_token_probs)
        logprobs_repr = self._shape_repr(self.logprobs)
        ids_repr = self._shape_repr(self.sampled_token_ids)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"logprobs={logprobs_repr}, "
            f"sampled_token_ids={ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
799
800


801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
@dataclass
class PoolerOutput:
    """Output of a pooling operation in the embedding model.

    Behaves like a list of ``EmbeddingSequenceGroupOutput`` (``len``,
    indexing and item assignment go to ``outputs``). Two instances are equal
    when they are of the same class and their ``outputs`` are equal; the
    metrics field is not compared.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __len__(self):
        group_outputs = self.outputs
        return len(group_outputs)

    def __getitem__(self, idx: int):
        group_outputs = self.outputs
        return group_outputs[idx]

    def __setitem__(self, idx: int, value):
        group_outputs = self.outputs
        group_outputs[idx] = value

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


822
823
824
825
826
@dataclass
class ExecuteModelRequest:
    """A model execution request.

    Carries the sequence group metadata to run, the block swap/copy
    operations that accompany it, and two scheduling counters.
    """
    # Metadata for each sequence group in this request.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in: (CPU block number, GPU block number) pairs.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out: (GPU block number, CPU block number) pairs.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy: (source block, destination block) pairs.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # How many lookahead slots to reserve for lookahead decoding.
    num_lookahead_slots: int = 0
    # How many requests are currently in the running queue.
    running_queue_size: int = 0

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Return a copy of this request using ``seq_group_metadata_list``.

        The three block lists are shallow-copied into new list objects, so
        appending to the clone's lists leaves this request untouched; the
        integer fields are carried over as-is.
        """
        swap_in_pairs = self.blocks_to_swap_in.copy()
        swap_out_pairs = self.blocks_to_swap_out.copy()
        copy_pairs = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=swap_in_pairs,
            blocks_to_swap_out=swap_out_pairs,
            blocks_to_copy=copy_pairs,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
        )