sequence.py 61.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm import envs
19
20
from vllm.inputs import SingletonInputs
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
23

24
if TYPE_CHECKING:
25
    from vllm.lora.request import LoRARequest
26
27
    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorOutput)
28
29
30
else:
    LoRARequest = Any
    KVConnectorOutput = Any
31

32
# `array.array` typecode used for every token-id buffer in this module
# ("l" is the signed-long typecode).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel token id (-1).
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int) -> array:
    """[`array`][] equivalent of [numpy.full][].

    Args:
        token_id: The value every element of the result is set to.
        count: The number of elements in the result.

    Returns:
        An array with typecode `VLLM_TOKEN_ID_ARRAY_TYPE` holding
        `token_id` repeated `count` times.
    """
    # Repeating a one-element array avoids building a Python list of
    # `count` items first.
    return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count


42
43
44
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of chosen token
        rank: The vocab rank of chosen token (>=1); defaults to None
        decoded_token: The decoded string of the chosen token; defaults
            to None
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


59
60
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = list[dict[int, Logprob]]
64

Woosuk Kwon's avatar
Woosuk Kwon committed
65

66
class SequenceStatus(enum.IntEnum):
    """Status of a sequence."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is one of the FINISHED_* members.

        Relies on the numeric ordering above: every finished status has a
        value greater than `SWAPPED`.
        """
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the finished statuses, and
            None for any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
Woosuk Kwon's avatar
Woosuk Kwon committed
98

99

100
101
102
103
104
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prefill or decode."""
    # Values are spelled out; they are the same ones `enum.auto()` assigns
    # (counting up from 1 in declaration order).
    PREFILL = 1
    DECODE = 2


105
106
107
108
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Required timestamp field that the original
                         docstring did not describe; presumably the time
                         of the most recent token (confirm against the
                         code that updates it).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
132
133


134
135
136
137
138
139
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step."""
    # New token ids to be appended to existing SequenceData.
    new_output_token_ids: list[int]
    # Overwriting existing `cumulative_logprob`
    new_cumulative_logprob: float
    # Overwriting existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwriting existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Holds the prompt and output token ids (as `array` buffers), optional
    prompt/output embeddings, the cumulative logprob, and bookkeeping for
    how many tokens have been computed. Several derived caches
    (`_prompt_token_ids_tuple`, `_cached_all_token_ids`,
    `_cached_all_token_embeds`) are rebuilt in `__post_init__` and kept in
    sync by the mutating methods below.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    _first_step_flag: bool = True

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # Fold the per-run arrays into the first one with in-place `+=`.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.

        Args:
            prompt_token_ids: Token ids of the prompt.
            output_token_ids: Token ids generated so far; None leaves the
                output empty.
            prompt_embeds: Optional prompt embeddings stored alongside the
                token ids.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        """Validate the array typecodes and build the derived caches."""
        assert self._prompt_token_ids.typecode == "l"
        assert self._output_token_ids.typecode == "l"
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self) -> None:
        """Rebuild `_cached_all_token_ids` as prompt ids + output ids."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self) -> None:
        """Rebuild `_cached_all_token_embeds` from prompt + output embeds.

        Requires `_prompt_embeds` to be a tensor; output embeds, when
        present, are concatenated after it along dim 0.
        """
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        """The cumulative log probability of the output."""
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        """The token IDs of the prompt."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        """Replacing the prompt is not supported."""
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses typecode "l" (C signed long), whose item
        size is platform-dependent, so it is not guaranteed to match the
        8-byte torch.long. So beware of the usage.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        """The token IDs of the output."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        """Replace the output token ids and refresh the all-tokens cache."""
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        """The embeddings of the output tokens, if any."""
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        """Replace the output embeddings and refresh the embeds cache."""
        # Assign the declared `_output_embeds` field. The previous code wrote
        # to `_output_token_embeds`, which is not a field of this struct, so
        # the getter and the cache rebuild below never saw the new value.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses typecode "l" (C signed long), whose item
        size is platform-dependent, so it is not guaranteed to match the
        8-byte torch.long. So beware of the usage.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """The embeddings of the prompt, if any."""
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        """Replace the prompt embeddings and refresh the embeds cache."""
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        """The stored mrope position delta (None until set)."""
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token and keep every cache in sync.

        Args:
            token_id: The new output token id.
            logprob: Its log probability; added to the cumulative logprob.
            token_embed: Optional 1-D embedding of the token. When given,
                `_cached_all_token_embeds` must already be populated.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            # Store on CPU with a leading dimension of size 1 so it can be
            # concatenated along dim 0.
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        """Return the cached list of all token ids (prompt then output).

        The internal list itself is returned, not a copy.
        """
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        """Return the cached embeddings of all tokens, if any."""
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            # The prefix covers the whole prompt plus some output tokens.
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token id, or the last prompt token id
        when nothing has been generated yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Return the prompt token ids as a tuple."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Return the output token ids as a tuple."""
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the changes since the last call and reset the delta.

        Returns:
            A `SequenceDataDelta` with the newly appended tokens, the
            current cumulative logprob, computed-token count and stage.
        """
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta) -> None:
        """Apply a `SequenceDataDelta` produced by `get_delta_and_reset`."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        """The current stage (prefill or decode)."""
        return self._stage

    def get_first_step_flag(self):
        """Return the stored first-step flag (True by default)."""
        return self._first_step_flag

    def set_first_step_flag(self, flag: bool):
        """Overwrite the stored first-step flag."""
        self._first_step_flag = flag

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
450
451


Woosuk Kwon's avatar
Woosuk Kwon committed
452
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance passed in through the `inputs`
    constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        # `self.prompt_token_ids` reads `self.inputs`, so this must come
        # after the assignment above.
        self.data = SequenceData.from_seqs(
            self.prompt_token_ids,
            prompt_embeds=self.inputs["prompt_embeds"]
            if self.inputs["type"] == "embeds" else None)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed for the current length (ceil division
        of the sequence length by `block_size`)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        """The prompt text, or None for "embeds" inputs or when absent."""
        if self.inputs["type"] == "embeds":
            return None
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids.

        For "embeds" inputs there are no real token ids, so a list of
        zeros with one entry per prompt embedding is returned instead.
        """
        if self.inputs["type"] == "embeds":
            return [0] * len(self.inputs["prompt_embeds"])
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal kwargs of "multimodal" inputs, else an empty
        `MultiModalKwargs`."""
        if self.inputs["type"] == "multimodal":
            return self.inputs["mm_kwargs"].get_data()

        return MultiModalKwargs()

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multi-modal placeholders of "multimodal" inputs, else `{}`."""
        if self.inputs["type"] == "multimodal":
            return self.inputs["mm_placeholders"]

        return {}

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when there is no LoRA request."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            # Advance the offset so the next call starts after this chunk.
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned"""
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens == 0:
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx`, together
        with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        if self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash(self.lora_int_id)

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self) -> None:
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append a generated token and its logprobs to the sequence.

        Args:
            token_id: The new token id; must be a key of `logprobs`.
            logprobs: Mapping of token id to `Logprob` for this step.
            token_embed: Optional embedding forwarded to the data object.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob,
                                  token_embed)

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        """Return all token ids (prompt then output)."""
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Return the prompt token ids as a tuple."""
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        """Return the id of the last token in the sequence."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Return the output token ids as a tuple."""
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        """Return the cumulative log probability of the output."""
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        """Return True if the sequence is in a finished status."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence with `new_seq_id` as id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        """Return True while the sequence is in the prefill stage."""
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
673

Woosuk Kwon's avatar
Woosuk Kwon committed
674

675
676
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group"""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of steps left: `num_steps - current_step`."""
        return self.num_steps - self.current_step


Woosuk Kwon's avatar
Woosuk Kwon committed
688
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
                    model; equal to max number of tokens a step can generate
                    for single-draft speculative decoding but larger than
                    that for multi-draft SD (currently not supported).
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Cached so that single-sequence fast paths avoid list indexing.
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority
        # Keep the documented constructor argument instead of dropping it.
        self.draft_size = draft_size

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """The prompt of the first sequence in the group."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids of the first sequence in the group."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        """The encoder sequence's prompt, or None without an encoder seq."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        """The encoder sequence's prompt token ids, or None without one."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the first sequence, else of the encoder seq.

        Falls back to an empty `MultiModalKwargs` when neither has any.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs()

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders of whichever sequence provides `multi_modal_data`.

        The first sequence is used when its `multi_modal_data` is truthy,
        otherwise the encoder sequence if present, otherwise an empty dict.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        if not envs.VLLM_ZERO_OVERHEAD:
            # If still in prefill phase, assertion fails.
            assert not self.is_prefill(), (
                "seq_group.set_last_token_time() should not be called "
                "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        if not envs.VLLM_ZERO_OVERHEAD:
            assert not self.is_prefill(), (
                "seq_group.get_last_token_latency() should not be called "
                "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the sequences, optionally restricted to one status.

        With `status=None` the internal list itself is returned (not a
        copy).
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Return True when an encoder sequence is attached."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None if there is none."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences for which `is_finished()` is true."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum the uncomputed token counts over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally restricted to one status."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Count the sequences for which `is_finished()` is true."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """Return True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Return whether the first sequence is in the prefill stage."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)

907

908
909
910
911
912
913
914
915
916
917
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.

    NOTE: the struct is declared with `array_like=True`, so the order of
    the fields below is part of the encoded format; do not reorder them.
    """
    # Per-sequence data deltas, keyed by sequence id; each one is applied to
    # the matching `seq_data` entry by `SequenceGroupMetadata.apply_delta`.
    seq_data_delta: dict[int, SequenceDataDelta]
    # Must equal the request id of the metadata the delta is applied to.
    request_id: str
    # Replacement block tables (seq id -> list of block numbers).
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    # NOTE(review): `computed_block_nums` and `state` are not read by
    # `SequenceGroupMetadata.apply_delta` - confirm whether that is intended.
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Attributes:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholders.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        token_chunk_size: Filled in by `__post_init__` when left as None.
        num_speculative_tokens: See the comment on the field below.
    """

    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Derive `token_chunk_size` when the caller did not provide one."""
        if self.seq_data is None or self.token_chunk_size is not None:
            return
        if not self.is_prompt:
            # Non-prompt steps default to a chunk of one token.
            self.token_chunk_size = 1
            return
        # Prompt steps default to the full length of the first sequence.
        first_seq_data = next(iter(self.seq_data.values()))
        self.token_chunk_size = first_seq_data.get_len()

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        """Whether this is a prompt step for which sampling is required."""
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        """Return the first key of `seq_data`."""
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        seq_id_iter = iter(self.seq_data)
        return next(seq_id_iter)

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Fold a delta into this metadata in place.

        Per-sequence deltas are applied to the matching `seq_data` entries,
        after which the block tables, chunk size and the `do_sample` /
        `is_prompt` flags are overwritten with the delta's values.
        """
        delta_md = sequence_group_metadata_delta
        for seq_id, seq_delta in delta_md.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        assert self.request_id == delta_md.request_id
        self.is_prompt = delta_md.is_prompt
        self.do_sample = delta_md.do_sample
        self.token_chunk_size = delta_md.token_chunk_size
        self.block_tables = delta_md.block_tables

    def finish_step(self) -> None:
        """Advance `state.current_step` by one; it must be below num_steps."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}"  # noqa
        )
        self.state.current_step += 1

1031

1032
1033
1034
1035
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Attributes:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor.
    """
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        """Debug string; only the shape of `output_embed` is shown."""
        embed = self.output_embed
        output_embed_shape = None if embed is None else embed.shape
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        `output_embed` does not take part in the comparison.

        Raises:
            NotImplementedError: If `other` is not a `SequenceOutput`.
        """
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        same_ids = (self.parent_seq_id == other.parent_seq_id
                    and self.output_token == other.output_token)
        same_logprobs = other.logprobs == self.logprobs
        return same_ids and same_logprobs
1066
1067


1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation; subclasses must override."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether two outputs are equal; subclasses must override."""


1080
1081
1082
1083
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        """Debug string with the samples and prompt logprobs."""
        samples_part = f"CompletionSequenceGroupOutput(samples={self.samples}, "
        logprobs_part = f"prompt_logprobs={self.prompt_logprobs})"
        return samples_part + logprobs_part

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs (`step_index` is ignored).

        Raises:
            NotImplementedError: If `other` is of a different type.
        """
        if isinstance(other, CompletionSequenceGroupOutput):
            return (self.samples == other.samples
                    and self.prompt_logprobs == other.prompt_logprobs)
        raise NotImplementedError()

1101

1102
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE(review): in Python 3 a `__metaclass__` class attribute has no
    # effect on class creation; this only stores a reference.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Return `data.nbytes`; `data` is treated as a `torch.Tensor`."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        """Debug string showing the pooled data."""
        # The closing parenthesis was missing from the original format string.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare the `data` of two outputs with `==`.

        Raises:
            NotImplementedError: If `other` is of a different type.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        # NOTE(review): if `data` is a tensor, `==` presumably yields an
        # elementwise tensor rather than a bool - confirm with callers.
        return self.data == other.data
1124
1125


1126
1127
1128
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own kv_connector_output.

    Attributes:
        tensors: Mapping from tensor name to tensor.
        kv_connector_output: Optional KV connector output of this stage;
            None unless one is passed to the constructor or assigned later.
    """

    tensors: dict[str, torch.Tensor]
    kv_connector_output: Optional[KVConnectorOutput]

    def __init__(self, tensors, kv_connector_output=None):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors
        # Always initialise the declared field so that reading it never
        # raises AttributeError on a freshly constructed instance.
        self.kv_connector_output = kv_connector_output

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        A `str` key returns the tensor stored under that name. A `slice`
        key returns a new instance whose tensors are each sliced with it;
        `kv_connector_output` is not carried over to the new instance.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        """Return the `(name, tensor)` items view of `tensors`."""
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same names with `torch.equal` tensors.

        `kv_connector_output` does not take part in the comparison.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[k], other.tensors[k])
            for k in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1174
1175
1176
1177
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Sum `get_data_nbytes()` over all contained outputs."""
        total_nbytes = 0
        for output in self.outputs:
            total_nbytes += output.get_data_nbytes()
        return total_nbytes

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        outputs = self.outputs
        return outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        outputs = self.outputs
        outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when `other` is of the same class with equal `outputs`."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1198
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Collect the sequence ids of every group into one flat list.

    The ids are the keys of each group's `seq_data`, taken group by group
    in list order.
    """
    all_seq_ids: list[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        # Extending with a dict appends its keys in iteration order.
        all_seq_ids.extend(seq_group_metadata.seq_data)
    return all_seq_ids


1206
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Returns:
        A pair of the flat list of sequence ids (in list order) and a
        mapping from request id to the set of its sequence ids.
    """
    all_seq_ids: list[int] = []
    ids_by_request: defaultdict[str, set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        group_seq_ids = list(seq_group_metadata.seq_data)
        if not group_seq_ids:
            # A group without sequences must not create a mapping entry.
            continue
        all_seq_ids.extend(group_seq_ids)
        ids_by_request[seq_group_metadata.request_id].update(group_seq_ids)
    return all_seq_ids, ids_by_request


1221
1222
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        """Derive `_seq_ids` from the metadata list when one is given."""
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """Sequence ids, one per row of `hidden_states`."""
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Rows that belong to sequence ids present in
        `seq_group_metadata_list` are removed first, then the new rows are
        appended, so stale rows for those ids are not kept.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        seq_ids = get_all_seq_ids(seq_group_metadata_list)

        # Rows to keep: every stored row whose sequence id is not refreshed.
        refreshed_seq_ids = set(seq_ids)
        index = [
            row for row, seq_id in enumerate(self._seq_ids)
            if seq_id not in refreshed_seq_ids
        ]
        kept_seq_ids = [self._seq_ids[row] for row in index]

        self.hidden_states = torch.cat(
            [self.hidden_states[index], hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states (zeros) when none were supplied, to
            # maintain the same shape as `hidden_states`.
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states[index],
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

        self._seq_ids = kept_seq_ids + seq_ids

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        # NOTE(review): assumes the stored sequence ids are unique.
        position = {seq_id: row for row, seq_id in enumerate(self._seq_ids)}
        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [
            seq_id for seq_id in get_all_seq_ids(seq_group_metadata_list)
            if seq_id in position
        ]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [position[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        # Rows of `second_last_token_hidden_states` sit after the rows of
        # `hidden_states` in the concatenated tensor, hence the offset.
        num_seqs = len(self._seq_ids)
        index = []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1326

1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
class Logits(msgspec.Struct, array_like=True,
             omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.
    Used in speculative decoding to pass lm_head logits from
    the target model to the proposer model in the subsequent step.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the logits tensor"""
    # Logits tensor; its first dimension has one entry per sequence id.
    logits: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        """Derive `_seq_ids` from the metadata list when one is given."""
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.logits)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """Sequence ids, one per row of `logits`."""
        return self._seq_ids

    def update(self,
               logits: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata]):
        """Update logits from target model invocation. Only used for
        decode steps.

        The new rows and their sequence ids are appended to the stored ones.
        """
        # NOTE(review): unlike `HiddenStates.update`, rows already stored for
        # the same sequence ids are not removed first - confirm intended.
        assert len(seq_group_metadata_list) == len(logits)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        known_seq_ids = set(self._seq_ids)
        # Only keep sequence IDs that exist in self._seq_ids (same rule as
        # HiddenStates.prune); looking up an untracked id below would raise
        # ValueError.
        seq_ids = [
            seq_id for seq_id in get_all_seq_ids(seq_group_metadata_list)
            if seq_id in known_seq_ids
        ]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            # `list.index` selects the first stored row for each id.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.logits = self.logits[index]
            self._seq_ids = seq_ids


1377
1378
1379
1380
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: the struct is declared with array_like=True, so the order of the
    # fields below must not be changed.

    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int, int]] = msgspec.field(
        default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int, int]] = msgspec.field(
        default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(
        default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # Optional logits from prior step.
    previous_logits: Optional[Logits] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    # Optional tree attention mask from draft model.
    tree_attn_masks: Optional[torch.Tensor] = None

    # Optional tree position ids from draft model.
    tree_position_ids: Optional[torch.Tensor] = None

    # Optional slot mapping of the KV cache entries that are pending to be
    # moved, generated by the draft model.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

    @property
    def is_last_step(self) -> bool:
        """Whether the first sequence group's state reports exactly one
        remaining step."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state.remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The ``current_step`` of the first sequence group's state."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        assert first_seq_group.state is not None
        return first_seq_group.state.current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The three block lists are shallow-copied and
        ``last_sampled_token_ids`` is cloned when present; every other field
        is handed to the new request as-is.
        """
        sampled_token_ids = self.last_sampled_token_ids
        if sampled_token_ids is not None:
            sampled_token_ids = sampled_token_ids.clone()

        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            previous_logits=self.previous_logits,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=sampled_token_ids,
            async_callback=self.async_callback,
            tree_attn_masks=self.tree_attn_masks,
            tree_position_ids=self.tree_position_ids,
            kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
        )
1463
1464
1465
1466
1467
1468
1469
1470
1471


@dataclass
class SequenceGroupBase:
    """Base bookkeeping for a request that is split into several requests.

    Subclasses implement `add_request` and `maybe_assemble_group`.
    """
    # the original request id before splitting
    group_id: str

    assembled_seq_group: Optional[SequenceGroup] = None

    # seq id to a unique index inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # seq ids to be finished
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # seq id to finished sequences
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Hook for splitting a request into multiple requests.

        Called when the request identified by `request_id`, with `params`,
        is ready to be added into the engine. Not implemented here.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that the sequence `seq` has finished.

        Its entry moves from `to_be_finished` to `finished_reqs`; a
        `KeyError` is raised if it was not pending.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Hook for assembling the sequence group, either to produce the
        final output or to add the request to the engine again.

        Not implemented here.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Group of the single-sample requests that one parallel-sampling
    request is split into."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split `request_id` into `params.n` requests and add each of them
        to `engine`.

        Request `i` is named ``f"{request_id}_parallel_sample_{i}"`` and gets
        a clone of `params` with ``n = 1`` and, when a seed is set, the seed
        offset by `i`. Each one is added through
        ``engine._add_processed_request`` (extra `kwargs` are forwarded) and
        mapped to the new group in ``engine.seq_id_to_seq_group``.

        Nothing is returned; the group is reachable through
        ``engine.seq_id_to_seq_group``.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        # NOTE: `seq_group` and `params` below are the ones left over from
        # the last loop iteration.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            priority=seq_group.priority,
        )

        group.streaming = params.output_kind == RequestOutputKind.DELTA
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group if `seq_group` should produce output.

        Streaming mode: the assembled group is returned only for the first
        request still listed in `to_be_finished`; `None` otherwise, including
        when nothing is pending.

        Non-streaming mode: the assembled group is returned once, when
        `seq_group` is the only request left in `to_be_finished` and it is
        finished. If the sampling params carry a non-None ``_real_n``, the
        group's sequences are first cut down to the top ones by cumulative
        logprob. `None` is returned in every other case.
        """
        if self.streaming:
            # in the streaming mode, we will return the assembled sequence
            # for the first remaining sequence, and then return None for the
            # rest of sequences
            if not self.to_be_finished:
                # Nothing pending: report no output instead of letting
                # next() leak a StopIteration to the caller.
                return None
            first_remaining_id = next(iter(self.to_be_finished))
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        # in the non-streaming mode, we will return the assembled sequence
        # when the last sequences finishes, and then return None for the
        # rest of the time
        is_last_pending = (len(self.to_be_finished) == 1
                           and seq_group.request_id in self.to_be_finished
                           and seq_group.is_finished())
        if not is_last_pending:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Get the top-n sequences.
            n = params._real_n or params.n
            sorted_seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)
            self.assembled_seq_group.seqs = sorted_seqs[:n]
        return self.assembled_seq_group