sequence.py 63 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm import envs
19
from vllm.inputs import SingletonInputs
20
from vllm.lora.request import LoRARequest
21
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
22
from vllm.pooling_params import PoolingParams
23
from vllm.sampling_params import RequestOutputKind, SamplingParams
24

25
# Typecode shared by every token-id `array` in this module ("l" = C signed
# long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# NOTE(review): not referenced in this section; presumably a sentinel for an
# unset/invalid token id - confirm against callers.
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int) -> array:
    """[`array`][] equivalent of [numpy.full][].

    Build an array of `count` copies of `token_id`, using the shared
    `VLLM_TOKEN_ID_ARRAY_TYPE` typecode.
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    # Sequence repetition on an array replicates the element `count` times.
    return single * count


35
36
37
# NOTE: Logprob is a dataclass (not a msgspec.Struct) because it is used in
# the OpenAI server output, where msgspec structs are not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Log-probability record for one token, in OpenAI-compatible form.

    Attributes:
        logprob: Log probability of the chosen token.
        rank: Vocabulary rank of the chosen token (>=1), if known.
        decoded_token: Decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group; `Sequence.append_token_id`
# appends one such dict per generated token.
SampleLogprobs = list[dict[int, Logprob]]
58

59
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Members are ordered so that every value greater than ``SWAPPED`` is a
    finished (terminal) state; `is_finished` relies on that ordering.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` is one of the FINISHED_* members."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns ``"stop"``, ``"length"`` or ``"abort"``; any status that is
        not one of the FINISHED_* members yields ``None``.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        # Ignored sequences are those whose prompt exceeds the model's length
        # cap, so, as in the OpenAI API, they are reported as "length" too.
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        return None
91

92

93
94
95
96
97
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: still prefilling, or decoding."""
    PREFILL = enum.auto()
    DECODE = enum.auto()


98
99
100
101
@dataclass
class RequestMetrics:
    """Timing metrics recorded for one request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time of the most recent token update (set by
                         callers; presumably when the last token was
                         generated).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    # Required fields (positional order is part of the constructor API).
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    # Optional fields, filled in as the request progresses.
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
125
126


127
128
129
130
131
132
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental SequenceData update sent to workers each step.

    Produced by `SequenceData.get_delta_and_reset` and consumed by
    `SequenceData.apply_delta`. The struct is `array_like`, so the field
    order below must not change.
    """
    # Newly generated tokens to append to the existing SequenceData.
    new_output_token_ids: list[int]
    # Replaces the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Replaces the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Replaces the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    # NOTE: positional construction (e.g. `SequenceData(prompt_arr)`) relies
    # on the field order below, so do not reorder fields.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int, ...] = msgspec.field(
        default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Cached prompt + output token ids (and embeddings), kept in sync by
    # `_update_cached_all_tokens` / `append_token_id` / `apply_delta`.
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    # Read/written via get/set_first_step_flag; its consumer is outside this
    # file (presumably marks the first step of a sequence - TODO confirm).
    _first_step_flag: bool = True

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # In-place concatenation avoids re-copying the growing prefix.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        # Both arrays must share the module typecode: they are concatenated
        # in `_update_cached_all_tokens`, which requires matching typecodes.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild `_cached_all_token_ids` from the prompt and output arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild `_cached_all_token_embeds` as prompt (+ output) embeds.

        Requires `_prompt_embeds` to be set; output embeddings, when present,
        are concatenated along dim 0.
        """
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids as the underlying `array`.

        The array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l", a C signed
        long), whose item size is platform dependent (8 bytes on LP64
        systems such as 64-bit Linux, 4 bytes on 64-bit Windows). Check
        `itemsize` before reinterpreting the buffer as `torch.long` (int64).
        The internal array is returned, not a copy.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Write the declared `_output_embeds` field: that is what the getter
        # returns and what `_update_cached_all_token_embeds` reads.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids as the underlying `array`.

        The array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l", a C signed
        long), whose item size is platform dependent (8 bytes on LP64
        systems such as 64-bit Linux, 4 bytes on 64-bit Windows). Check
        `itemsize` before reinterpreting the buffer as `torch.long` (int64).
        The internal array is returned, not a copy.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token (and optionally its embedding).

        Updates the output array, the delta buffer, the cached full token
        list and the cumulative logprob. If `token_embed` is given it must be
        1-D; it is detached, moved to CPU and appended to the output
        embeddings and to the cached full embeddings.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        # Returns the internal cached list (not a copy).
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the state changed since the last call, then clear the
        pending appended-token buffer."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset` on another copy."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def get_first_step_flag(self) -> bool:
        return self._first_step_flag

    def set_first_step_flag(self, flag: bool) -> None:
        self._first_step_flag = flag

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
453
class Sequence:
    """One sequence of a request: its token data, status and block layout.

    A sequence wraps a single `SingletonInputs` mapping (`inputs`). When the
    inputs are of type ``"embeds"``, the prompt token ids are zero
    placeholders and the prompt embeddings are handed to `SequenceData`.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        # Embedding prompts forward their embeddings; token ids are zeros.
        prompt_ids = self.prompt_token_ids
        embeds = (self.inputs["prompt_embeds"]
                  if self.inputs["type"] == "embeds" else None)
        self.data = SequenceData.from_seqs(prompt_ids, prompt_embeds=embeds)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets of what the delta-output getters have already returned.
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # State for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all tokens."""
        size = self.block_size
        return (self.get_len() + size - 1) // size

    @property
    def prompt(self) -> Optional[str]:
        is_embeds = self.inputs["type"] == "embeds"
        return None if is_embeds else self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> list[int]:
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        # One zero placeholder per prompt embedding row.
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def token_type_ids(self) -> list[int]:
        if self.inputs["type"] != "embeds":
            return self.inputs.get("token_type_ids", [])
        return []

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        is_mm = self.inputs["type"] == "multimodal"
        return self.inputs["mm_kwargs"] if is_mm else MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        is_mm = self.inputs["type"] == "multimodal"
        return self.inputs["mm_placeholders"] if is_mm else {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text to hand back to the caller.

        While the sequence is unfinished, the last `buffer_length` characters
        are held back. With `delta=True`, only text not returned by a
        previous call is returned and the internal offset is advanced.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()

        if not delta:
            if truncate:
                return self.output_text[:-buffer_length]
            return self.output_text

        end = len(self.output_text) - (buffer_length if truncate else 0)
        start = self._last_output_text_offset
        if start >= end:
            return ""
        self._last_output_text_offset = end
        return self.output_text[start:end]

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return output token ids; with `delta=True`, only the new ones.

        In delta mode a single new token is returned as a bare int (the
        common decode case), no new tokens as an empty list, otherwise a
        list slice of the new tokens.
        """
        if not delta:
            return self.get_output_token_ids()

        current_len = self.get_output_len()
        num_new = current_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current_len

        if num_new == 0:
            return []
        all_ids = self.data._cached_all_token_ids
        if num_new == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return all_ids[-1]
        return all_ids[-num_new:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash of all tokens up to and including block `logical_idx`,
        combined with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix = self.data.get_prefix_token_ids(
            self.num_hashed_tokens_of_block(logical_idx))
        return hash((prefix, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        Extra hash input for prefix caching, covering factors other than the
        token ids. Currently only the LoRA id contributes; returns None when
        no LoRA adapter is used.
        """
        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return None if self.lora_int_id == 0 else hash(self.lora_int_id)

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen_logprob = logprobs[token_id].logprob
        self.data.append_token_id(token_id, chosen_logprob, token_embed)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence under a new id."""
        clone = copy.deepcopy(self)
        clone.seq_id = new_seq_id
        return clone

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            1 while decoding; otherwise the number of not-yet-computed
            tokens (the remaining prefill).
        """
        in_decode = self.data.stage == SequenceStage.DECODE
        return 1 if in_decode else self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
681

682
683
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state attached to one sequence group.

    Currently holds only the multi-step decoding counters.
    """

    # Number of steps for multi-step decoding.
    num_steps: int = 1
    # Step currently reached within `num_steps`.
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps still to run: `num_steps - current_step`."""
        total, done = self.num_steps, self.current_step
        return total - done
695
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
                    model; equal to max number of tokens a step can generate
                    for single-draft speculative decoding but larger than
                    that for multi-draft SD (currently not supported).
                    NOTE: accepted for interface compatibility only; it is
                    neither stored nor used by this class.
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Fast-path flag consulted by several accessors below so they can
        # skip scanning `self.seqs`.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        # Gap between the two most recent last-token timestamps; updated by
        # set_last_token_time().
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """Prompt text of the (decoder) prompt."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """Token ids of the (decoder) prompt."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the decoder sequence, falling back to the
        encoder sequence, or an empty MultiModalKwargs if neither exists."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders matching the source chosen by `multi_modal_data`."""
        # NOTE: the branch is selected by whether the decoder sequence has
        # multi-modal *data*, mirroring `multi_modal_data` above.
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step state to run `num_steps` steps from 0."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize multi-step state from the scheduler configuration.

        Without multi-step scheduling the group runs `num_scheduler_steps`.
        With it, a chunked prefill runs `num_lookahead_slots` steps and every
        other case runs `num_lookahead_slots + 1` steps.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-Step case
        is_prefill = self.is_prefill()

        # The asserts below reflect the expectations of the current system.
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
        else:
            is_decode: bool = not is_prefill
            # If it is a prefill, num_lookahead_slots must be 0
            assert num_lookahead_slots == 0 or is_decode
            # If it is a decode, num_lookahead_slots + 1 must match
            # the scheduler steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
            self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        if not envs.VLLM_ZERO_OVERHEAD:
            # If still in prefill phase, assertion fails.
            assert not self.is_prefill(), (
                "seq_group.set_last_token_time() should not be called "
                "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        if not envs.VLLM_ZERO_OVERHEAD:
            assert not self.is_prefill(), (
                "seq_group.get_last_token_latency() should not be called "
                "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: if a sequence group is swapped out and recomputed, the time
        # between iterations is counted in TPOT rather than recalculating
        # TTFT, since from the user's point of view there is simply a long
        # generation delay.
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return all sequences, or only those whose status is `status`."""
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """True when this group carries an encoder sequence."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences whose status is finished."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total uncomputed tokens across all unfinished sequences."""
        return sum(seq.data.get_num_uncomputed_tokens() for seq in self.seqs
                   if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Number of sequences, optionally restricted to `status`."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Number of sequences whose status is finished."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Prefill state of the group, taken from its first sequence."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update for an already-sent SequenceGroupMetadata.

    After the scheduler has shipped the full SequenceGroupMetadata of a
    request once, later steps only send this delta to shrink the payload.

    NOTE: with ``array_like=True`` the struct is encoded as an array, so the
    declaration order of the fields below is part of the wire format.
    """
    # Per-sequence updates keyed by sequence id.
    seq_data_delta: dict[int, SequenceDataDelta]
    # Identifies the metadata this delta applies to.
    request_id: str
    # Replacement block tables (seq id -> list of physical block numbers).
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters for a pooling model, if any.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type ids of the prompt.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Placeholder ranges for multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required; filled in by __post_init__.
    """

    # NOTE: array_like=True encodes fields positionally; keep this order.
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size when none was given: a prompt processes the
        # full length of its first sequence, a decode step one token.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a delta.

        Applies each per-sequence delta and replaces block_tables,
        token_chunk_size, do_sample and is_prompt with the delta's values.
        NOTE(review): the delta's computed_block_nums and state are not
        applied here.

        Raises:
            AssertionError: if the delta belongs to a different request. The
                check runs before any mutation, so a mismatched delta leaves
                this object unchanged.
        """
        delta = sequence_group_metadata_delta
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one completed step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, "
            f"num_steps {self.state.num_steps}")
        self.state.current_step += 1
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional embedding of the output token. It is shown
            (by shape) in ``__repr__`` but ignored by ``__eq__``.
    """

    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        output_embed_shape = (self.output_embed.shape
                              if self.output_embed is not None else None)
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            # Defer to Python's fallback (reflected __eq__, then identity)
            # instead of raising on comparisons with unrelated types.
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
class SequenceGroupOutput(ABC):
    """Abstract interface for the model output of one sequence group.

    Concrete subclasses must implement ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a human-readable description of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: `__metaclass__` is a Python 2 idiom with no effect in Python 3;
    # it only marks the intended SequenceGroupOutput interface.
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    # Not considered by __repr__ or __eq__.
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CompletionSequenceGroupOutput):
            # Let Python fall back instead of raising for unrelated types.
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE: `__metaclass__` is a Python 2 idiom with no effect in Python 3;
    # it only marks the intended SequenceGroupOutput interface.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Size in bytes of `data`, which is expected to be a tensor."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PoolingSequenceGroupOutput):
            # Let Python fall back instead of raising for unrelated types.
            return NotImplemented
        if isinstance(self.data, torch.Tensor) and isinstance(
                other.data, torch.Tensor):
            # `==` on tensors is elementwise; reduce to a single bool so the
            # result is usable in `if`, `assert` and list equality.
            return torch.equal(self.data, other.data)
        return self.data == other.data
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Return the tensor named `key`, or for a slice a new
        IntermediateTensors with every tensor sliced along its first dim.

        The sliced copy does not carry finished_sending/finished_recving.
        Keys of any other type return None.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when `other` has the same tensor names and every pair of
        tensors has identical shape and values (via `torch.equal`)."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[k], other.tensors[k])
            for k in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model.

    Acts as a thin list wrapper over per-group outputs: it supports
    indexing, item assignment, ``len`` and equality on the wrapped list.
    """
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Total byte size of the data held by every wrapped output."""
        total = 0
        for group_output in self.outputs:
            total += group_output.get_data_nbytes()
        return total

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Flatten the sequence ids of all groups into one list.

    Args:
        seq_group_metadata_list: The sequence groups to read.

    Returns:
        Every key of each group's ``seq_data``, in group order and, within a
        group, in the dict's iteration order.
    """
    all_ids: list[int] = []
    for group_metadata in seq_group_metadata_list:
        all_ids.extend(group_metadata.seq_data)
    return all_ids
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Args:
        seq_group_metadata_list: The sequence groups to read.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:
        ``seq_ids`` is the flat list of every sequence id, in the same order
        as ``get_all_seq_ids`` produces; ``request_id_seq_ids_mapping`` maps
        each request id to the set of its sequence ids (it is a
        ``defaultdict(set)``).
    """
    seq_ids: list[int] = []
    request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # NOTE: array_like=True encodes fields positionally; keep this order.

    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted
    # tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Rows of sequences that appear in `seq_group_metadata_list` are
        dropped and replaced by the new rows, which are appended at the end;
        rows of all other sequences are kept in their current order.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
        # Set lookup keeps the filter linear instead of O(n * m).
        refreshed = set(new_seq_ids)
        keep = [
            row for row, seq_id in enumerate(self._seq_ids)
            if seq_id not in refreshed
        ]
        kept_seq_ids = [self._seq_ids[row] for row in keep]

        self.hidden_states = torch.cat(
            [self.hidden_states[keep], hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            incoming_second_last = (torch.zeros_like(hidden_states)
                                    if second_last_token_hidden_states is None
                                    else second_last_token_hidden_states)
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states[keep],
                incoming_second_last
            ])

        self._seq_ids = kept_seq_ids + new_seq_ids

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        # Row position of each tracked seq id (first occurrence), so lookups
        # below are O(1) instead of repeated list scans.
        row_of: dict[int, int] = {}
        for row, seq_id in enumerate(self._seq_ids):
            row_of.setdefault(seq_id, row)
        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [seq_id for seq_id in seq_ids if seq_id in row_of]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [row_of[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = \
                    self.second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For each sequence in `seq_with_bonus_token_in_last_step`, its
        second-last-token row is inserted right before its regular row.
        `_seq_ids` is left unchanged.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        num_rows = len(self._seq_ids)
        index = []
        for row, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                # Row `row` of second_last_token_hidden_states sits at
                # `row + num_rows` in the concatenation below.
                index.append(row + num_rows)
            index.append(row)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]
class Logits(msgspec.Struct, array_like=True,
             omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.

    Used in speculative decoding to pass lm_head logits from the target
    model to the proposer model in the subsequent step.

    `seq_ids` are the sequence ids of each entry of the batch dimension of
    the logits tensor.
    """
    # Logits from the target model; when seq ids are tracked, row i of the
    # batch dimension belongs to sequence `_seq_ids[i]`.
    logits: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Sequence id of each row of `logits`, in row order.
    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        # When metadata is supplied, it must describe exactly one row per
        # entry of `logits`; derive the per-row seq ids from it.
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.logits)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """Sequence id of each row of `logits`, in row order."""
        return self._seq_ids

    def update(self,
               logits: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Append logits from a target model invocation. Only used for
        decode steps.

        Args:
            logits: New rows to append after the existing ones.
            seq_group_metadata_list: Metadata describing the new rows; must
                have the same length as `logits`.
        """
        assert len(seq_group_metadata_list) == len(logits)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered to follow `seq_group_metadata_list`. Sequence ids
        in it that are not currently tracked are skipped rather than raising
        `ValueError`, the same filtering applied to the hidden states.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune
        # sequences which are confirmed to be aborted.

        # Row of each tracked seq id. `setdefault` keeps the first occurrence
        # (same result as `list.index`) without an O(n) scan per lookup.
        row_of: dict[int, int] = {}
        for row, seq_id in enumerate(self._seq_ids):
            row_of.setdefault(seq_id, row)

        # Only keep sequence IDs that exist in self._seq_ids.
        seq_ids = [
            seq_id for seq_id in get_all_seq_ids(seq_group_metadata_list)
            if seq_id in row_of
        ]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [row_of[seq_id] for seq_id in seq_ids]
            self.logits = self.logits[index]
            self._seq_ids = seq_ids


1414
1415
1416
1417
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: the struct is `array_like`, so field order is part of the
    # serialized format; do not reorder fields.

    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # Optional logits from prior step.
    previous_logits: Optional[Logits] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None
    # Optional tree attention mask from draft model.
    tree_attn_masks: Optional[torch.Tensor] = None
    # Optional tree position ids from draft model.
    tree_position_ids: Optional[torch.Tensor] = None
    # Optional KV-cache slot mapping, generated by the draft model, of the
    # entries that are pending a move.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

    def _first_seq_group_state(self):
        """Return the `state` of the first sequence group.

        Asserts that the batch is non-empty and that the state is set. All
        step-related properties below read only this first group's state.
        """
        # TODO(will) make this be able to handle batches with variable number
        # of steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first sequence group is at step 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group has exactly one step left."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """Current step index of the first sequence group."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The block lists and `finished_requests_ids` are shallow-copied, and
        `last_sampled_token_ids` is cloned. The hidden-states/logits
        carriers, the tree tensors, the KV-cache slot mapping and the
        callback are shared with `self`.
        """
        last_sampled_token_ids = (self.last_sampled_token_ids.clone()
                                  if self.last_sampled_token_ids is not None
                                  else None)
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            previous_logits=self.previous_logits,
            num_steps=self.num_steps,
            # Copied like the block lists so that mutating the clone's list
            # does not affect this request.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=last_sampled_token_ids,
            async_callback=self.async_callback,
            tree_attn_masks=self.tree_attn_masks,
            tree_position_ids=self.tree_position_ids,
            kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a request that is split into several sub-requests.

    Each sub-request is its own `SequenceGroup`. Subclasses decide how the
    split happens (`add_request`) and when an output for the original
    request should be produced (`maybe_assemble_group`).
    """
    group_id: str  # the original request id before splitting

    # The sequence group reported for the original request, if assembled.
    assembled_seq_group: Optional[SequenceGroup] = None

    # sub-request id -> unique index inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # sub-request id -> sequence group, for sub-requests not yet finished
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # sub-request id -> sequence group, for finished sub-requests
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    # Whether outputs are streamed incrementally (set by subclasses).
    streaming: bool = False

    # Whether the final output has already been returned once.
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that sub-request `seq` has finished.

        Moves it from `to_be_finished` to `finished_reqs`.

        Raises:
            KeyError: if `seq.request_id` is not in `to_be_finished`.
        """
        del self.to_be_finished[seq.request_id]
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Must be implemented by subclasses.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Split a request asking for `n` samples into `n` single-sample
    sub-requests, and report their sequences as one assembled group."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Add one single-sample sub-request per sample of `params`.

        Sub-request ``i`` gets the id ``f"{request_id}_parallel_sample_{i}"``.
        Its params are a clone of `params` with ``n = 1`` and, when a seed is
        set, the seed offset by ``i``. Every sub-request is mapped to the
        shared group in ``engine.seq_id_to_seq_group``. The group's
        `assembled_seq_group` collects the first sequence of each
        sub-request under the original `request_id` and `params`.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        seq_group = None
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # The shared metadata below is taken from the last sub-request, so at
        # least one sample is required.
        assert seq_group is not None, "parallel sampling requires n >= 1"

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            priority=seq_group.priority,
        )

        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when an output should be produced.

        Streaming mode: return the assembled group only when `seq_group` is
        the first sub-request still in `to_be_finished`, and None otherwise.

        Non-streaming mode: return it exactly once, when `seq_group` is the
        last unfinished sub-request and has finished. If the sampling params
        carry `_real_n`, the assembled sequences are first trimmed to the
        top-n by cumulative logprob. Otherwise return None.
        """
        if self.streaming:
            # `to_be_finished` can be empty once every sub-request has been
            # finished; there is then no remaining sub-request to report for.
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        if (len(self.to_be_finished) == 1
                and seq_group.request_id in self.to_be_finished
                and seq_group.is_finished()
                and not self.output_produced):
            assert self.assembled_seq_group is not None
            params = self.assembled_seq_group.sampling_params
            assert isinstance(params, SamplingParams)
            self.output_produced = True
            if params._real_n is not None:
                # Keep only the top-n sequences by cumulative logprob.
                n = params._real_n or params.n
                self.assembled_seq_group.seqs = sorted(
                    self.assembled_seq_group.seqs,
                    key=lambda seq: seq.get_cumulative_logprob(),
                    reverse=True)[:n]
            return self.assembled_seq_group
        return None