sequence.py 60.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Sequence and its related classes."""
3
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import enum
5
from abc import ABC, abstractmethod
6
from array import array
7
from collections import defaultdict
8
9
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
10
from dataclasses import dataclass, field
11
from functools import reduce
12
from typing import Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
import msgspec
15
16
import torch

17
from vllm.inputs import SingletonInputs
18
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
20
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
# Typecode for every token-id `array` in this module (C signed long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel token id.
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """Return an `array` holding `token_id` repeated `count` times.

    This is the [`array`][] counterpart of [numpy.full][].
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single * count


34
35
36
# We use a plain dataclass (rather than msgspec.Struct) for now because it is
# used in the OpenAI server output, where msgspec structs are not
# serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


51
52
# Prompt logprobs of a sequence group: a list of {token_id -> Logprob} dicts.
# An entry is None when the corresponding sequence group doesn't require
# prompt logprobs.
PromptLogprobs = list[Optional[dict[int, Logprob]]]

# Sampled logprobs: a list of {token_id -> Logprob} dicts, one per generated
# token (Sequence.append_token_id appends one dict per output token).
SampleLogprobs = list[dict[int, Logprob]]
56

Woosuk Kwon's avatar
Woosuk Kwon committed
57

58
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    The numeric order is meaningful: every value greater than ``SWAPPED``
    denotes a finished sequence.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Everything after SWAPPED (2) counts as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* members."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns ``"stop"``, ``"length"`` or ``"abort"``; any other status
        (including the unfinished ones) yields None.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so they report "length" as well, matching
            # the OpenAI API.
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
90

91

92
93
94
95
96
class SequenceStage(enum.Enum):
    """Phase of a sequence: ``PREFILL`` while uncomputed tokens remain,
    ``DECODE`` once every token has been computed."""
    PREFILL = 1
    DECODE = 2


97
98
99
100
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time associated with the most recent token of
            the request (presumably refreshed as tokens are produced —
            confirm against callers).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
        spec_token_acceptance_counts: number of accepted speculative tokens at
                                      each position; the first token is from
                                      the target model and is always accepted;
                                      e.g., when it's [10, 8, 4, 2] for a req,
                                      it means there were 10 forward passes in
                                      total, and there were 8, 4, 2 accepted
                                      tokens at 1st, 2nd, 3rd speculation step.
    """
    # Required fields (no defaults); order matters for positional callers.
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    # Optional fields, filled in as the request progresses.
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
    spec_token_acceptance_counts: Optional[list[int]] = None
132
133


134
135
136
137
138
139
class SequenceDataDelta(msgspec.Struct,
                        array_like=True,  # type: ignore[call-arg]
                        omit_defaults=True):  # type: ignore[call-arg]
    """Per-step incremental update of a SequenceData, sent to workers.

    The struct is encoded array-like (positionally), so the declaration
    order of the fields below must stay stable.
    """
    # Output token ids to append to the existing SequenceData.
    new_output_token_ids: list[int]
    # Replaces the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Replaces the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Replaces the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer the `from_seqs` / `from_prompt_token_counts` constructors; the
    fields below the "should not be passed" marker are internal state.

    Args:
        _prompt_token_ids: The token IDs of the prompt, as an `array` with
            typecode VLLM_TOKEN_ID_ARRAY_TYPE.
        _output_token_ids: The token IDs of the output, same array type.
            Defaults to an empty array.
        _prompt_embeds: Optional prompt embeddings.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    # Tuple copy of the prompt ids (rebuilt in __post_init__) so prefix
    # lookups can return hashable values.
    _prompt_token_ids_tuple: tuple[int, ...] = msgspec.field(
        default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Prompt + output token ids, kept in sync whenever output tokens change.
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    # Prompt + output embeddings; only built when prompt embeds are present.
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # Every `array_full` result is a fresh array, so accumulating with the
        # in-place `__iadd__` never mutates caller-owned data.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        """Validate array typecodes and build the derived caches."""
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild `_cached_all_token_ids` as prompt ids + output ids."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild `_cached_all_token_embeds` as prompt embeds followed by
        output embeds (concatenated along dim 0). Requires prompt embeds."""
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Prompt tokens are immutable once the sequence is created.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids as an `array`.

        The array uses typecode VLLM_TOKEN_ID_ARRAY_TYPE ("l", a C signed
        long). Its item size is platform-dependent (at least 4 bytes), so
        check `itemsize` before reinterpreting the buffer as a torch dtype
        such as torch.long (8 bytes).
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Store into the declared field (`_output_embeds`); the cache rebuild
        # below reads it from there.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids as an `array`.

        The array uses typecode VLLM_TOKEN_ID_ARRAY_TYPE ("l", a C signed
        long). Its item size is platform-dependent (at least 4 bytes), so
        check `itemsize` before reinterpreting the buffer as a torch dtype
        such as torch.long (8 bytes).
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one output token, updating every derived cache.

        Args:
            token_id: The generated token id.
            logprob: Its logprob, added to the cumulative logprob.
            token_embed: Optional 1-D embedding of the token; it is moved to
                CPU and appended to the output / all-token embeddings. Only
                valid when the all-token embedding cache already exists.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns `(prompt_prefix, output_prefix)`; `output_prefix` is None
        when the first `num_tokens` tokens lie entirely within the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if no
        output has been produced yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Snapshot the changes since the last call and clear the appended
        token buffer."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a SequenceDataDelta: overwrite the scalar state and append
        the delta's output tokens."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
450
451


Woosuk Kwon's avatar
Woosuk Kwon committed
452
class Sequence:
    """Holds the token data, status and block bookkeeping of one sequence.

    A sequence is built from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance handed in through the `inputs`
    constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        # Identity and configuration, stored as given.
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        # Embedding-type inputs forward their prompt embeddings.
        embeds = (self.inputs["prompt_embeds"]
                  if self.inputs["type"] == "embeds" else None)
        self.data = SequenceData.from_seqs(self.prompt_token_ids,
                                           prompt_embeds=embeds)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # How much output the delta-returning getters have already handed out.
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # State for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (ceiling division)."""
        size = self.block_size
        return (self.get_len() + size - 1) // size

    @property
    def prompt(self) -> Optional[str]:
        kind = self.inputs["type"]
        return None if kind == "embeds" else self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> list[int]:
        # Embedding prompts have no real ids; use one placeholder 0 per row.
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def token_type_ids(self) -> list[int]:
        if self.inputs["type"] != "embeds":
            return self.inputs.get("token_type_ids", [])
        return []

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        if self.inputs["type"] != "multimodal":
            return MultiModalKwargs({})
        return self.inputs["mm_kwargs"]

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        if self.inputs["type"] != "multimodal":
            return {}
        return self.inputs["mm_placeholders"]

    @property
    def lora_int_id(self) -> int:
        if not self.lora_request:
            return 0
        return self.lora_request.lora_int_id

    @property
    def prompt_adapter_id(self) -> int:
        if not self.prompt_adapter_request:
            return 0
        return self.prompt_adapter_request.prompt_adapter_id

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text, holding back `buffer_length` trailing
        characters while the sequence is unfinished.

        With `delta=True` only text not returned by a previous delta call is
        produced."""
        # Finished sequences always get their full text.
        truncate = bool(buffer_length) and not self.is_finished()
        text = self.output_text

        if not delta:
            return text[:-buffer_length] if truncate else text

        end = len(text) - buffer_length if truncate else len(text)
        start = self._last_output_text_offset
        if start >= end:
            return ""
        self._last_output_text_offset = end
        return text[start:end]

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return output token ids; with `delta=True`, only those produced
        since the previous delta call (a bare int when exactly one)."""
        if not delta:
            return self.get_output_token_ids()

        current = self.get_output_len()
        fresh = current - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current

        if fresh == 0:
            return []
        all_ids = self.data._cached_all_token_ids
        # A single decode token is by far the most common case.
        return all_ids[-1] if fresh == 1 else all_ids[-fresh:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash of all tokens up to and including logical block
        `logical_idx`, combined with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix = self.data.get_prefix_token_ids(
            self.num_hashed_tokens_of_block(logical_idx))
        return hash((prefix, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        Extra hash for prefix caching mode, covering factors other than the
        token ids. Returns None when neither a prompt adapter nor a LoRA
        adapter is active.
        """
        adapter_id = self.prompt_adapter_id
        lora_id = self.lora_int_id
        if adapter_id == 0 and lora_id == 0:
            return None
        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash((adapter_id, lora_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record a generated token together with its logprobs dict, which
        must contain an entry for `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen = logprobs[token_id].logprob
        self.data.append_token_id(token_id, chosen, token_embed)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence under a new id."""
        clone = copy.deepcopy(self)
        clone.seq_id = new_seq_id
        return clone

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        is_decode = self.data.stage == SequenceStage.DECODE
        return 1 if is_decode else self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
687

Woosuk Kwon's avatar
Woosuk Kwon committed
688

689
690
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state attached to one sequence group."""

    # Multi-step decoding: total steps requested and the current step index.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps still to run: `num_steps - current_step`."""
        total, done = self.num_steps, self.current_step
        return total - done


Woosuk Kwon's avatar
Woosuk Kwon committed
702
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must be non-empty: the first sequence
            supplies the prompt-level data exposed by this group.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
            for single-draft speculative decoding but larger than
            that for multi-draft SD (currently not supported).
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Computed once here; the helpers below use it to skip list scans
        # in the single-sequence case.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None,
                                      spec_token_acceptance_counts=[0] *
                                      draft_size)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        # Initialised to None here; not assigned anywhere else in this class.
        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the decoder sequence, falling back to the
        encoder sequence, or an empty `MultiModalKwargs` if neither has any."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders paired with `multi_modal_data` (same fallback
        order: decoder sequence, then encoder sequence, else empty)."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        if self.prompt_adapter_request:
            return (self.prompt_adapter_request.
                    prompt_adapter_num_virtual_tokens)
        return 0

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step counters to run `num_steps` steps."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialise multi-step state from the scheduler's lookahead slots.

        Without multi-step, the group runs `num_scheduler_steps` steps. With
        multi-step, a chunked prefill runs `num_lookahead_slots` steps and
        every other case runs `num_lookahead_slots + 1` steps.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-Step case
        is_prefill = self.is_prefill()

        # The asserts below reflect the expectations of the current system.
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
        else:
            is_decode: bool = not is_prefill
            # If it is a prefill, num_lookahead_slots must be 0
            assert num_lookahead_slots == 0 or is_decode
            # If it is a decode, num_lookahead_slots + 1 must match
            # the scheduler steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
            self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, assertion fails.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: when a sequence group is swapped out and later recomputed,
        # the time between iterations is counted in TPOT instead of
        # recalculating TTFT, since from the user's point of view there is
        # simply a long generation delay.
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return all sequences, or only those whose status equals `status`.

        With no filter the internal `self.seqs` list itself is returned.
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far (unfinished seqs only)."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total uncomputed tokens across the unfinished sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.seqs if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.first_seq.status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.first_seq.is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)

967

968
969
970
971
972
973
974
975
976
977
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update for an already-sent `SequenceGroupMetadata`.

    Once the full `SequenceGroupMetadata` for a group has been sent, the vLLM
    scheduler only sends these deltas, which keeps the payload small.
    """
    # NOTE: with array_like=True fields are encoded positionally, so their
    # declaration order below is part of the wire format.

    # Per-sequence data deltas, keyed by sequence id.
    seq_data_delta: dict[int, SequenceDataDelta]
    request_id: str
    # Seq id -> list of block numbers.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    # Each instance gets its own fresh SequenceGroupState.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters, if any.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type IDs.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Placeholders for the multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required. When left as None,
            `__post_init__` fills it in: the full length of the first
            sequence for prompts, 1 otherwise.
        num_speculative_tokens: Number of speculative tokens adopted in this
            request (see the field comment below).
    """

    # NOTE: array_like=True encodes fields positionally; do not reorder.
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default chunk size: the whole prompt for a prompt step, one token
        # otherwise.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        if self.prompt_adapter_request:
            return (self.prompt_adapter_request.
                    prompt_adapter_num_virtual_tokens)
        return 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Merge a `SequenceGroupMetadataDelta` into this metadata in place.

        Applies each per-sequence data delta and overwrites `block_tables`,
        `token_chunk_size`, `do_sample` and `is_prompt` with the delta's
        values.
        """
        delta = sequence_group_metadata_delta
        # Validate before mutating anything so a delta for another request
        # cannot leave this object partially updated.
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, \
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
        self.state.current_step += 1

1104

1105
1106
1107
1108
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a single sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor.
    """

    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        # Only the embedding's shape is shown, not its values.
        embed = self.output_embed
        embed_shape = None if embed is None else embed.shape
        return ("SequenceOutput("
                f"parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        # Comparison covers parent id, token and logprobs; output_embed is
        # not compared.
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and other.logprobs == self.logprobs)
1138
1139


1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
class SequenceGroupOutput(ABC):
    """Abstract base for the model outputs of one sequence group.

    Concrete subclasses must provide `__repr__` and `__eq__`.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a readable description of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with `other`."""


1152
1153
1154
1155
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: in Python 3 a `__metaclass__` class attribute has no effect; it
    # is only a plain attribute referring to SequenceGroupOutput here.
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        samples, prompt_logprobs = self.samples, self.prompt_logprobs
        return (f"CompletionSequenceGroupOutput(samples={samples}, "
                f"prompt_logprobs={prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        # step_index is not part of the comparison.
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        if self.samples != other.samples:
            return False
        return self.prompt_logprobs == other.prompt_logprobs

1173

1174
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE: in Python 3 a `__metaclass__` class attribute has no effect; it
    # is only a plain attribute referring to SequenceGroupOutput here.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def __repr__(self) -> str:
        # Closing parenthesis was previously missing from this repr.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.data == other.data
1192
1193


1194
1195
1196
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: dict[str, torch.Tensor]

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor by name, or a new instance with every tensor
        sliced by `key`.

        Raises:
            KeyError: if a string key is not present.
            TypeError: if `key` is neither a str nor a slice.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and silently returned None.
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal iff `other` is the same class with the same keys and
        element-wise equal tensors (same shapes and values)."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1233
1234
1235
1236
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model.

    Behaves like a thin list wrapper: indexing, item assignment and `len`
    are forwarded to `outputs`.
    """
    outputs: list[PoolingSequenceGroupOutput]

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        items = self.outputs
        return items[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        items = self.outputs
        items[idx] = value

    def __len__(self):
        items = self.outputs
        return len(items)

    def __eq__(self, other: object):
        # Non-PoolerOutput objects are never equal.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1254
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Flatten the sequence IDs of every group into one list.

    IDs appear group by group, each group's IDs in the iteration order of
    its `seq_data` dict.
    """
    all_ids: list[int] = []
    for group in seq_group_metadata_list:
        all_ids.extend(group.seq_data)
    return all_ids


1262
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect sequence IDs both as a flat list and grouped by request.

    Returns:
        A tuple `(seq_ids, request_id_seq_ids_mapping)`. `seq_ids` lists
        every sequence ID in group order. The mapping is a `defaultdict(set)`
        from request ID to the set of that request's sequence IDs; a
        request only gets an entry if it contributed at least one ID.
    """
    flat_ids: list[int] = []
    ids_by_request: defaultdict[str, set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        req_id = group.request_id
        for seq_id in group.seq_data:
            flat_ids.append(seq_id)
            ids_by_request[req_id].add(seq_id)
    return flat_ids, ids_by_request


1277
1278
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.

    Used in speculative decoding to pass hidden states from the target model
    to the proposer model.

    `seq_ids` are the sequence ids of each entry of the batch dimension of
    the `hidden_states` tensor.
    """
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        # NOTE: if this object has no second-last states yet, any passed-in
        # `second_last_token_hidden_states` is ignored.
        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)

        # Row position of each currently held seq_id. setdefault keeps the
        # first occurrence, matching the list.index() lookup used previously,
        # but makes each lookup O(1) instead of O(n).
        positions: dict[int, int] = {}
        for row, seq_id in enumerate(self._seq_ids):
            positions.setdefault(seq_id, row)

        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [seq_id for seq_id in seq_ids if seq_id in positions]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [positions[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = (
                    self.second_last_token_hidden_states[index])
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For each sequence in the bonus set, its second-last-token row is
        inserted immediately before its regular row.
        """
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        # Rows [0, n) of the concatenation below are hidden_states and rows
        # [n, 2n) are second_last_token_hidden_states.
        num_seqs = len(self._seq_ids)
        index = []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1364

1365
1366
1367
1368
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: with `array_like=True` msgspec encodes this struct as an array,
    # so the field order below is part of the serialized format. Do not
    # reorder fields.

    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # The step index for spec model input.
    spec_step_idx: Optional[int] = None
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self):
        """Return the `state` of the first sequence group in the batch.

        The step properties below consult only the first group. The batch
        must be non-empty and that group's `state` must be set; both are
        checked with assertions.
        """
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """Whether the first sequence group is at step 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """Whether the first sequence group has exactly one step left."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The current step index of the first sequence group."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        Every field is carried over. List fields are shallow-copied, so
        appending to the clone's lists does not affect this request.
        `last_sampled_token_ids` is cloned when present.
        `previous_hidden_states` and `async_callback` are shared by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Carried over so a clone does not silently reset to None.
            spec_step_idx=self.spec_step_idx,
            # Copied like the other list fields rather than aliased.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)
1447
1448
1449
1450
1451
1452
1453
1454
1455


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a request that is split into several sub-requests.

    Subclasses decide how a request is split (`add_request`) and how the
    results of the sub-requests are merged back into a single
    `SequenceGroup` (`maybe_assemble_group`).
    """
    # Id of the original request, before it was split.
    group_id: str

    # Merged view of the sub-requests, once it has been built.
    assembled_seq_group: Optional[SequenceGroup] = None

    # Sub-request id -> its unique position inside this group.
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # Sub-requests still pending, keyed by request id.
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # Sub-requests that have completed, keyed by request id.
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    # Whether outputs are streamed incrementally.
    streaming: bool = False

    # Whether the assembled output has already been returned.
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Split `request_id` into sub-requests and add them to `engine`.

        Invoked once the request and its `params` are ready to enter the
        engine. Subclasses must override this.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Move the completed sub-request `seq` from the pending mapping to
        the finished mapping.

        Raises:
            KeyError: If `seq.request_id` is not currently pending.
        """
        self.to_be_finished.pop(seq.request_id)
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the merged group to report for `seq_group`, or None.

        The merged group is used to produce the final output or to add the
        request to the engine again. Subclasses must override this.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Parallel sampling (`n` samples per prompt) implemented as `n`
    single-sample sub-requests whose sequences are merged back into one
    `SequenceGroup`."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Add one single-sample sub-request to `engine` per requested sample.

        Sub-request `i` is named ``f"{request_id}_parallel_sample_{i}"``, has
        ``n = 1``, and, when a seed is set, uses ``seed + i``.

        Args:
            request_id: Id of the original request.
            engine: Engine providing `_add_processed_request` and the
                `seq_id_to_seq_group` mapping.
            params: Sampling parameters of the original request.
            **kwargs: Forwarded to `engine._add_processed_request`.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            sub_params = original_params.clone()
            sub_params.n = 1
            if sub_params.seed is not None:
                # Offset the seed so each sample is distinct yet reproducible.
                sub_params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=sub_params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change. Shared request attributes are taken from the
        # last sub-request.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group to report for `seq_group`, or None.

        Streaming mode: the assembled group is returned only when
        `seq_group` is the first sub-request still pending. It is None for
        every other sub-request, and also when nothing is pending.

        Non-streaming mode: the assembled group is returned exactly once,
        when the last pending sub-request reports that it is finished. If
        `_real_n` is set, the sequences are first trimmed to the top
        `_real_n` (falling back to `n`) by cumulative logprob.
        """
        if self.streaming:
            # A default avoids leaking StopIteration once every sub-request
            # has been marked finished.
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        is_last_to_finish = (len(self.to_be_finished) == 1
                             and seq_group.request_id in self.to_be_finished
                             and seq_group.is_finished())
        if not is_last_to_finish:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            self.assembled_seq_group.seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)[:n]
        return self.assembled_seq_group