sequence.py 30.1 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.block import LogicalTokenBlock
9
from vllm.lora.request import LoRARequest
10
from vllm.pooling_params import PoolingParams
11
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
14
if TYPE_CHECKING:
    import torch
15

16
17
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

18
19
20

@dataclass
class Logprob:
    """Per-token log-probability record for OpenAI-compatible logprobs.

    Attributes:
        logprob: Log probability of the chosen token.
        rank: Vocabulary rank of the chosen token (>= 1). Defaults to None.
        decoded_token: Decoded text of the chosen token (a string).
            Defaults to None.
    """
    # Log probability of the chosen token.
    logprob: float
    # 1-based vocabulary rank; None when not provided.
    rank: Optional[int] = field(default=None)
    # Decoded token text; None when not provided.
    decoded_token: Optional[str] = field(default=None)
# Prompt logprobs: a list with one entry per prompt token, each entry being a
# {token_id -> Logprob} mapping. Individual entries may be None.
PromptLogprobs = List[Union[Dict[int, Logprob], None]]
# Sample logprobs: a list with one {token_id -> Logprob} mapping per generated
# token (Sequence.append_token_id appends one mapping per sampled token).
SampleLogprobs = List[Dict[int, Logprob]]
class SequenceStatus(enum.Enum):
    """Lifecycle status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* states."""
        terminal_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in terminal_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into its finish-reason string.

        Returns "stop", "length" or "abort" for the FINISHED_* statuses and
        None for any other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so, as in the OpenAI API, they also
            # report "length".
            return "length"
        return None
class SequenceStage(enum.Enum):
    """Processing stage of a sequence's token data."""
    # Some tokens of the sequence have not been computed yet.
    PREFILL = enum.auto()
    # Every token so far has been computed.
    DECODE = enum.auto()
@dataclass
class RequestMetrics:
    """Timing metrics recorded for a single request.

    Attributes:
        arrival_time: Time at which the request arrived.
        last_token_time: Time at which the most recent token was recorded.
        first_scheduled_time: Time at which the request was first scheduled.
        first_token_time: Time at which the first token was generated.
        time_in_queue: Time the request spent waiting in the queue.
        finished_time: Time at which the request finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = field(default=None)
class SequenceData:
    """Token-level data for a single sequence.

    Args:
        prompt_token_ids: Token IDs of the prompt.
        output_token_ids: Token IDs generated so far. ``None`` is replaced
            by a fresh empty list.

    Attributes:
        prompt_token_ids: Token IDs of the prompt.
        output_token_ids: Token IDs of the output.
        cumulative_logprob: Sum of the logprobs passed to append_token_id.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = (output_token_ids
                                 if output_token_ids is not None else [])
        self.cumulative_logprob = 0.0
        # How many tokens have already been run against the model.
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Total token count: prompt plus output."""
        prompt_count = len(self.prompt_token_ids)
        output_count = len(self.output_token_ids)
        return prompt_count + output_count

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list: prompt token IDs followed by output token IDs."""
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return how many tokens have been computed so far."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Advance the computed-token counter.

        Moves the stage to DECODE once no uncomputed tokens remain.
        """
        self._num_computed_tokens += num_new_computed_tokens
        total = self.get_len()
        assert self._num_computed_tokens <= total, (
            self._num_computed_tokens, total)
        # Nothing left to compute means the sequence is now decoding.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Forget all computed tokens so the sequence is processed again from
        the beginning (e.g., after the sequence is preempted)."""
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens that are not computed yet."""
        # Output tokens are included (get_len, not the prompt length): on
        # recompute both the prompt and the output must be prefilled again.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Last output token, or the last prompt token if there is no output."""
        source = self.output_token_ids or self.prompt_token_ids
        return source[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return ("SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
class Sequence:
    """A single sequence: its token data, status, and logical block layout.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: End-of-sequence token ID, stored as
            ``self.eos_token_id``.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data: SequenceData = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        # The prompt tokens are laid out into logical blocks immediately.
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # State for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens.
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        """LoRA integer ID, or 0 when no LoRA request is attached."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, holding back the trailing
        ``buffer_length`` characters while the sequence is unfinished."""
        # A finished sequence always returns its full output text.
        if buffer_length and not self.is_finished():
            return self.output_text[:-buffer_length]
        return self.output_text

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash of every token up to the end of block ``logical_idx``,
        combined with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        token_prefix = tuple(self.data.get_token_ids()[0:num_tokens])
        return hash((token_prefix, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx (inclusive)."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        """Open a new empty logical block at the end of the block list."""
        self.logical_token_blocks.append(
            LogicalTokenBlock(
                block_number=len(self.logical_token_blocks),
                block_size=self.block_size,
            ))

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the logical blocks, opening a new block
        whenever the current last block is full."""
        offset = 0
        while offset < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            tail = self.logical_token_blocks[-1]
            if tail.is_full():
                self._append_logical_block()
                tail = self.logical_token_blocks[-1]

            free_slots = tail.get_num_empty_slots()
            tail.append_tokens(token_ids[offset:offset + free_slots])
            offset += free_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a sampled token together with its logprob mapping.

        ``logprobs`` must contain ``token_id``; that entry's logprob is added
        to the sequence's cumulative logprob.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            total_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            drop_eos = (eos_token_id is not None
                        and self.get_last_token_id() == eos_token_id)
            seq_len = total_len - 1 if drop_eos else total_len
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence that carries ``new_seq_id``."""
        clone = copy.deepcopy(self)
        clone.seq_id = new_seq_id
        return clone

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            1 while decoding; otherwise the number of tokens that are still
            uncomputed (the remaining prefill).
        """
        decoding = self.data.stage == SequenceStage.DECODE
        return 1 if decoding else self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
    """Mutable state attached to one specific sequence group."""

    # torch.Generator used for seeded sampling; None when not set. Typed as
    # a bare Optional (hence the type: ignore) — torch is only imported
    # under TYPE_CHECKING in this module.
    generator: Optional = field(default=None)  # type: ignore
class MultiModalData:
    """Multi-modal payload attached to a request.

    Args:
        type: Kind of the payload (a ``MultiModalData.Type`` member).
        data: The payload itself. Its required shape and meaning depend on
            the vision language config of the hosted model; see
            `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        """Supported kinds of multi-modal payload."""
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.data = data
        self.type = type
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
    ) -> None:
        self.request_id = request_id
        # Sequences keyed by their seq_id.
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # last_token_time starts at the arrival time, so until it is first
        # updated, get_last_latency() measures from arrival.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data
        self.embeddings = embeddings
        self.pooling_params = pooling_params

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Return the time elapsed since the last recorded token and record
        ``now`` as the new last-token time.

        Args:
            now: The current time.

        Returns:
            ``now`` minus the previous ``metrics.last_token_time``.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Record ``time`` as the first-token time, if none is recorded yet
        and the first sequence has exactly one output token."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings. Only the first call has an effect."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        if (self.sampling_params
                and self.sampling_params.best_of > self.num_seqs()):
            # At prompt stage, the sequence group is not yet filled up
            # and only have one sequence running. However, in the
            # generation stage, we will have `best_of` sequences running.
            return self.sampling_params.best_of
        # At sampling stages, return the number of actual sequences
        # that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status equals ``status``."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Advance the computed-token count of every unfinished sequence."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total uncomputed tokens across all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with ``seq_id``; ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add ``seq`` to the group; ValueError if its ID is already present."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with ``seq_id``; ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes query tokens for prefill.
        pooling_params: The pooling parameters, for an embedding model.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required; it is then set to the length
            of the first entry in ``seq_data`` for prompts, and to 1
            otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A new
            SequenceGroupState is created when None.
        multi_modal_data: Multi modal data.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            if is_prompt:
                # Without chunking, a prompt is processed in full.
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> Optional[int]:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with another SequenceOutput.

        Returns NotImplemented for any other type so that Python falls back
        to its default comparison instead of raising (which previously
        broke ``==`` / ``in`` checks against unrelated objects).
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
class SequenceGroupOutput(ABC):
    """Abstract base for the model output of one sequence group.

    Concrete subclasses must implement ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of this output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""
class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: One SequenceOutput per candidate next token.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs with another completion output.

        Returns NotImplemented for other types (instead of raising) so that
        comparisons against unrelated objects fall back to Python's default.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Attributes:
        embeddings: The embedding values produced for this group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Report only the length; printing the whole vector would be noisy.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings with another embedding group output.

        For objects of any other type this returns ``NotImplemented`` rather
        than raising, so ``==`` against unrelated objects evaluates normally.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


739
740
741
742
743
@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value) -> None:
        self.outputs[idx] = value

    def __len__(self) -> int:
        return len(self.outputs)

    def __eq__(self, other: object) -> bool:
        # Only the per-group outputs take part in equality; the optional
        # device tensors and metrics are ignored.
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    @staticmethod
    def _shape_repr(tensor: Optional["torch.Tensor"]):
        """Return ``tensor.shape``, or the string ``"None"`` when unset."""
        return "None" if tensor is None else tensor.shape

    def __repr__(self) -> str:
        """Show the shape of each tensor instead of its values to reduce noise.
        """
        probs_repr = self._shape_repr(self.sampled_token_probs)
        # Previously omitted: include the logprobs tensor like the others.
        logprobs_repr = self._shape_repr(self.logprobs)
        ids_repr = self._shape_repr(self.sampled_token_ids)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"logprobs={logprobs_repr}, "
            f"sampled_token_ids={ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
787
788


789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
@dataclass
class PoolerOutput:
    """Result of a pooling operation in an embedding model.

    Wraps a list of ``EmbeddingSequenceGroupOutput`` and supports list-style
    indexing, item assignment and ``len``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Speculative-decoding worker metrics, if any were populated.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        item = self.outputs[idx]
        return item

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __eq__(self, other: object):
        # Equality considers only the wrapped outputs, not the metrics.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


810
811
812
813
814
@dataclass
class ExecuteModelRequest:
    """A request asking the model executor to run one step.

    Bundles the sequence-group metadata to process with the cache block
    operations (swap in, swap out, copy) and scheduling hints.
    """
    # Metadata of every sequence group to run in this step.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # (CPU block, GPU block) pairs to swap in.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # (GPU block, CPU block) pairs to swap out.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # (source block, destination block) pairs to copy.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # How many slots to reserve for lookahead decoding.
    num_lookahead_slots: int = 0
    # How many requests are currently in the running queue.
    running_queue_size: int = 0

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Return a copy of this request with a different metadata list.

        The three block-operation lists are shallow-copied so the clone's
        lists can be mutated independently of this request's lists; the
        integer fields are carried over unchanged.
        """
        swap_in_copy = self.blocks_to_swap_in.copy()
        swap_out_copy = self.blocks_to_swap_out.copy()
        copy_ops_copy = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            running_queue_size=self.running_queue_size,
            num_lookahead_slots=self.num_lookahead_slots,
            blocks_to_copy=copy_ops_copy,
            blocks_to_swap_out=swap_out_copy,
            blocks_to_swap_in=swap_in_copy,
            seq_group_metadata_list=seq_group_metadata_list,
        )