sequence.py 33 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
import torch

11
from vllm.inputs import LLMInputs
12
from vllm.lora.request import LoRARequest
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
15

16
if TYPE_CHECKING:
17
    from vllm.multimodal import MultiModalData
18
19
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

20
21
22

@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>= 1). Defaults to None.
        decoded_token: The decoded text of the chosen token (a string, not a
            token index). Defaults to None.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


35
36
# Prompt logprobs: one {token_id -> logprob} mapping per prompt token. Each
# entry is Optional, so a prompt position may have no mapping (None).
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: one {token_id -> logprob} mapping per generated token
# (Sequence.append_token_id appends one mapping for every new token).
SampleLogprobs = List[Dict[int, Logprob]]
40

Woosuk Kwon's avatar
Woosuk Kwon committed
41
42

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* statuses."""
        terminal_statuses = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in terminal_statuses

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into its finish-reason string.

        Returns "stop", "length" or "abort" for the FINISHED_* statuses and
        None for any other status.
        """
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so, as in the OpenAI API, their reason is
            # also "length".
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for candidate, reason in reason_table:
            if status == candidate:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
77

78

79
80
81
82
83
class SequenceStage(enum.Enum):
    """Processing stage of a sequence.

    A sequence stays in PREFILL until all of its tokens have been computed,
    after which it is in DECODE.
    """
    PREFILL = enum.auto()
    DECODE = enum.auto()


84
85
86
87
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The reference time for the next token-latency
            measurement. SequenceGroup initializes it to the arrival time and
            get_last_latency() advances it on every call.
        first_scheduled_time: The time when the request was first scheduled
            (None until set).
        first_token_time: The time when the first token was generated
            (None until set).
        time_in_queue: The time the request spent in the queue (None until
            set).
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


103
class SequenceData:
    """Token-level bookkeeping for a single sequence.

    Tracks the prompt tokens, the tokens generated so far, the running sum
    of the logprobs of generated tokens, and how many tokens have been run
    through the model (which determines the PREFILL/DECODE stage).

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        # Hashable snapshot of the prompt; get_prefix_token_ids() slices it.
        self._prompt_token_ids_tuple = tuple(prompt_token_ids)
        self.output_token_ids = (output_token_ids
                                 if output_token_ids is not None else [])
        self.cumulative_logprob = 0.0
        # How many tokens have already been run through the model.
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total token count (prompt plus output)."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of generated tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return the prompt tokens followed by the output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Return the first ``num_tokens`` tokens in a hashable form.

        The result is a ``(prompt_part, output_part)`` pair of tuples; when
        the prefix fits entirely inside the prompt, ``output_part`` is None.
        """
        prompt_count = len(self.prompt_token_ids)
        if num_tokens <= prompt_count:
            return (self._prompt_token_ids_tuple[:num_tokens], None)
        from_output = num_tokens - prompt_count
        return (self._prompt_token_ids_tuple,
                tuple(self.output_token_ids[:from_output]))

    def get_num_computed_tokens(self) -> int:
        """Return how many tokens have been computed so far."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Add ``num_new_computed_tokens`` to the computed-token counter.

        The counter may never exceed get_len() (checked by an assertion).
        Once no uncomputed tokens remain, the stage switches to DECODE.
        """
        self._num_computed_tokens += num_new_computed_tokens
        total_len = self.get_len()
        assert self._num_computed_tokens <= total_len, (
            self._num_computed_tokens, total_len)
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Forget all computed tokens and return to the PREFILL stage.

        Intended for a sequence that must be started from the beginning
        again (e.g., after it was preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return how many tokens still have to be computed."""
        # Measured against get_len() (prompt + output) rather than the prompt
        # length alone: a recomputed sequence must prefill both.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the newest output token, or the last prompt token when
        nothing has been generated yet."""
        if self.output_token_ids:
            return self.output_token_ids[-1]
        return self.prompt_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        """Return the prompt token list (not a copy)."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        """Return the output token list (not a copy)."""
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """The current PREFILL/DECODE stage of this sequence."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
208
209


Woosuk Kwon's avatar
Woosuk Kwon committed
210
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence. ``prompt_token_ids`` must be
            present; ``prompt`` and ``multi_modal_data`` are optional.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: LLMInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all prompt and output tokens."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        return self.inputs.get("multi_modal_data")

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int) -> str:
        """Return the output text, holding back its last ``buffer_length``
        characters while the sequence is unfinished.

        The full text is returned once the sequence is finished, or when
        ``buffer_length`` is 0.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix up to and including logical block
        ``logical_idx``, combined with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..``logical_idx`` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a sampled token; ``logprobs`` must contain ``token_id``."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        # Go through the SequenceData accessor like the sibling getters do;
        # it returns the same list object as data.output_token_ids.
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence that uses ``new_seq_id``."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # The previous version ended with ", " and never closed the
        # parenthesis, producing a malformed repr.
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
369

Woosuk Kwon's avatar
Woosuk Kwon committed
370

Nick Hill's avatar
Nick Hill committed
371
372
373
374
375
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
    generator: Optional[torch.Generator] = None
Nick Hill's avatar
Nick Hill committed
377
378


Woosuk Kwon's avatar
Woosuk Kwon committed
379
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt_token_ids

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last recorded token time and
        record ``now`` as the new last token time (Request level timings).

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings (only the first call has an effect)."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status is ``status``."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far for every unfinished
        sequence in the group."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with ``seq_id``; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add ``seq``; raise ValueError if its ID is already present."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with ``seq_id``; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
574
575


576
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes query tokens for prefill.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If None, it defaults to the length of the first sequence in
            ``seq_data`` for prompts, and to 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A new
            SequenceGroupState is created if None.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalData"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            if is_prompt:
                # Prompt: process the whole (first) sequence in one chunk.
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                # Decode: one new token per step.
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size

661

Zhuohan Li's avatar
Zhuohan Li committed
662
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when parent id, output token and logprobs all match."""
        if not isinstance(other, SequenceOutput):
            # Per the Python data model, signal "unsupported comparison" so
            # Python falls back to its default handling (e.g. `x == None` is
            # False) instead of raising.
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal
695
696


697
698
699
700
701
702
703
704
705
706
707
708
709
710
class SequenceGroupOutput(ABC):
    """Abstract base for the model output of one sequence group.

    Concrete subclasses must implement both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a human-readable description of this output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The sampled outputs, one SequenceOutput per candidate.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when both samples and prompt logprobs match."""
        if not isinstance(other, CompletionSequenceGroupOutput):
            # Let Python fall back to its default handling rather than
            # raising on comparisons with unrelated types.
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

731

732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group."""

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        # The embedding values produced for this group.
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown, to keep large vectors out of logs.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings element-wise.

        Objects of another type get ``NotImplemented`` (not an exception),
        so ``==`` falls back to Python's default handling and yields False.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional[torch.Tensor] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # NOTE: __eq__ and __repr__ are defined explicitly below, so @dataclass
    # keeps these versions instead of generating its own.

    def __getitem__(self, idx: int):
        """Return the per-group output at ``idx`` (list-style access)."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the per-group output at ``idx``."""
        self.outputs[idx] = value

    def __len__(self):
        """Number of per-group outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when ``other`` is of this class and ``outputs`` match.

        The tensor fields and metrics are not compared.
        """
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.

        ``logprobs`` and ``hidden_states`` are not included.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                                  self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


@dataclass
class PoolerOutput:
    """Container for what the embedding model's pooling step returns.

    It behaves like a list of ``EmbeddingSequenceGroupOutput``: indexing,
    item assignment and ``len()`` are forwarded to ``outputs``. Equality
    looks only at ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Optional speculative-decoding metrics attached to this output.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        item = self.outputs[idx]
        return item

    def __setitem__(self, idx: int, value):
        outputs = self.outputs
        outputs[idx] = value

    def __eq__(self, other: object):
        # Anything that is not an instance of this class compares unequal.
        if not isinstance(other, type(self)):
            return False
        return self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into a single list.

    Ids are emitted group by group, each group contributing the keys
    yielded by iterating its ``seq_data``.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        all_ids.extend(group.seq_data)
    return all_ids


class HiddenStates:
    """Hidden states for sequences that are still in progress.

    Speculative decoding uses this to hand the target model's hidden states
    to the proposer model on the following step. ``seq_ids[i]`` is the
    sequence id of row ``i`` along the batch (first) dimension of
    ``hidden_states``.
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        # One row of hidden states is expected per metadata entry.
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append the states produced by another target-model invocation.

        ``seq_ids`` is extended in place and the new rows are concatenated
        after the existing ones.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_ids)
        self.hidden_states = torch.cat((self.hidden_states, hidden_states))

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the given sequences, in the given order.

        Raises ValueError (from ``list.index``) if a requested id is not
        currently tracked.
        """
        wanted_ids = get_all_seq_ids(seq_group_metadata_list)
        if wanted_ids == self.seq_ids:
            # Batch contents unchanged; nothing to prune.
            return
        rows = list(map(self.seq_ids.index, wanted_ids))
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = wanted_ids


@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        Args:
            seq_group_metadata_list: Metadata list for the clone; it is
                stored as-is (not copied).

        Returns:
            A new ExecuteModelRequest. The three block-mapping lists are
            shallow copies, so editing them on the clone does not affect
            this request. The int fields are carried over, and
            ``previous_hidden_states`` is shared by reference, not copied.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
        )