sequence.py 60.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm.inputs import SingletonInputs
19
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.prompt_adapter.request import PromptAdapterRequest
23
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
# Typecode handed to `array.array` for every token-id buffer in this module
# ("l" is the stdlib typecode for a signed C long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel token id; -1 is used to mark an invalid token.
VLLM_INVALID_TOKEN_ID = -1

29

30
def array_full(token_id: int, count: int) -> array:
    """[`array`][] equivalent of [numpy.full][].

    Builds a one-element token-id array holding `token_id` and repeats it
    `count` times, so the result has `count` copies of `token_id` (and is
    empty when `count` is zero or negative).
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single * count


35
36
37
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
38
39
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if available.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


52
53
# Prompt logprobs: one entry per position, each either a
# {token_id -> Logprob} mapping or None when the corresponding
# sequence group doesn't require prompt logprobs.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} mapping per entry.
SampleLogprobs = list[dict[int, Logprob]]
57

Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    The integer ordering is meaningful: every member whose value is larger
    than `SWAPPED` counts as a finished status.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` is ordered after `SWAPPED`."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the finished statuses, and None
            for any other status.
        """
        reasons = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for finished_status, reason in reasons:
            if status == finished_status:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
91

92

93
94
95
96
97
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prefill or decode."""
    PREFILL = enum.auto()
    DECODE = enum.auto()


98
99
100
101
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Presumably the time when the most recent token was
                         generated (not described in the original docs;
                         confirm against the code that updates it).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
        spec_token_acceptance_counts: number of accepted speculative tokens at
                                      each position; the first token is from
                                      the target model and is always accepted;
                                      e.g., when it's [10, 8, 4, 2] for a req,
                                      it means there were 10 forward passes in
                                      total, and there were 8, 4, 2 accepted
                                      tokens at 1st, 2nd, 3rd speculation step.
    """
    # Required fields (no defaults); order is part of the constructor API.
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    # Optional fields, all defaulting to None.
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
    spec_token_acceptance_counts: Optional[list[int]] = None
133
134


135
136
137
138
139
140
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Carries the output tokens appended since the previous delta, plus the
    values that replace the receiver's cumulative logprob, computed-token
    count and stage.
    """
    # New output token ids to be appended to the existing SequenceData.
    new_output_token_ids: list[int]
    # Replacement value for the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Replacement value for the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Replacement value for the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Instances are normally built through
    [`from_seqs`][vllm.sequence.SequenceData.from_seqs] or
    [`from_prompt_token_counts`][vllm.sequence.SequenceData.from_prompt_token_counts]
    rather than by passing struct fields directly.

    Args:
        _prompt_token_ids: The token IDs of the prompt, as an `array` with
            typecode `VLLM_TOKEN_ID_ARRAY_TYPE`.
        _output_token_ids: The token IDs of the output, as an `array` with
            typecode `VLLM_TOKEN_ID_ARRAY_TYPE`. Defaults to an empty array.
        _prompt_embeds: Optional prompt embeddings.
        _output_embeds: Optional output embeddings.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if not token_counts:
            return SequenceData.from_seqs([])

        # Every operand is a freshly built array, so the in-place add only
        # ever mutates a temporary created here.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.

        Args:
            prompt_token_ids: Token IDs of the prompt.
            output_token_ids: Token IDs of the output; when None the struct's
                default (an empty array) is used.
            prompt_embeds: Optional prompt embeddings stored alongside the
                token IDs.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        """Validate the array typecodes and build the derived caches."""
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token-id list from scratch."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild the cached embeddings: prompt embeds followed, when
        present, by output embeds concatenated along dim 0."""
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        """The cumulative log probability of the output."""
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        """The prompt token IDs as an immutable tuple."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Replacing the prompt after construction is not supported.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l",
        a signed C long). Its item size is platform dependent and is not
        guaranteed to match torch.long, so beware of the usage.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        """The output token IDs, copied into a new tuple on every access."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        # The combined token cache depends on the output ids; rebuild it.
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        """The output embeddings, or None if there are none."""
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Store into the declared `_output_embeds` field so that the getter
        # and the rebuilt embedding cache both observe the new value.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l",
        a signed C long). Its item size is platform dependent and is not
        guaranteed to match torch.long, so beware of the usage.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """The prompt embeddings, or None if there are none."""
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        """Delta used to compute mrope_position_ids, if set."""
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one output token and keep every derived cache in sync.

        Args:
            token_id: The token ID to append to the output.
            logprob: Log probability added to the cumulative logprob.
            token_embed: Optional 1-D embedding of the token; it is detached,
                moved to CPU and appended to the output/all-token embeds.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        """Return the total number of prompt and output tokens."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        """Return the cached prompt + output token-id list (not a copy)."""
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        """Return the cached prompt + output embeddings, if any."""
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns:
            A pair `(prompt_part, output_part)`. When `num_tokens` fits in
            the prompt, `output_part` is None; otherwise `prompt_part` is the
            whole prompt and `output_part` holds the remaining tokens taken
            from the output.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token when no
        output has been generated yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Return the prompt token IDs as a tuple."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Return the output token IDs as a tuple."""
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the state changed since the last call into a delta and
        start a fresh list of newly appended tokens."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta: overwrite the scalar state and append the new
        output tokens to both the output array and the all-token cache."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        """The current stage (prefill or decode) of the sequence."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
451
452


Woosuk Kwon's avatar
Woosuk Kwon committed
453
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance passed in through the `inputs`
    constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        # Token ids are resolved first; embeddings are only forwarded for
        # inputs of type "embeds".
        initial_token_ids = self.prompt_token_ids
        initial_embeds = None
        if self.inputs["type"] == "embeds":
            initial_embeds = self.inputs["prompt_embeds"]
        self.data = SequenceData.from_seqs(initial_token_ids,
                                           prompt_embeds=initial_embeds)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets remembering how much output has already been returned
        # by the delta-style getters below.
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (rounded up)."""
        total_tokens = self.get_len()
        size = self.block_size
        return (total_tokens + size - 1) // size

    @property
    def prompt(self) -> Optional[str]:
        """The prompt text, or None for "embeds" inputs."""
        is_embeds = self.inputs["type"] == "embeds"
        return None if is_embeds else self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids; "embeds" inputs yield one 0 per embedding."""
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def token_type_ids(self) -> list[int]:
        """The token type ids, or an empty list when there are none."""
        if self.inputs["type"] != "embeds":
            return self.inputs.get("token_type_ids", [])
        return []

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """The multimodal kwargs; empty for non-multimodal inputs."""
        if self.inputs["type"] != "multimodal":
            return MultiModalKwargs({})
        return self.inputs["mm_kwargs"]

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """The multimodal placeholders; empty for non-multimodal inputs."""
        if self.inputs["type"] != "multimodal":
            return {}
        return self.inputs["mm_placeholders"]

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def prompt_adapter_id(self) -> int:
        """The prompt adapter id, or 0 when no adapter request is attached."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        hold_back = bool(buffer_length) and not self.is_finished()

        if delta:
            visible_end = len(self.output_text)
            if hold_back:
                visible_end -= buffer_length
            start = self._last_output_text_offset
            if start >= visible_end:
                return ""
            self._last_output_text_offset = visible_end
            return self.output_text[start:visible_end]

        if hold_back:
            return self.output_text[:-buffer_length]
        return self.output_text

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned"""
        if not delta:
            return self.get_output_token_ids()

        current_len = self.get_output_len()

        # Work out how many tokens are new, then advance the offset.
        fresh_count = current_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current_len

        if fresh_count == 0:
            return []

        all_ids = self.data._cached_all_token_ids
        if fresh_count == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return all_ids[-1]

        return all_ids[-fresh_count:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx` together
        with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix = self.data.get_prefix_token_ids(
            self.num_hashed_tokens_of_block(logical_idx))
        return hash((prefix, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        no_adapter = self.prompt_adapter_id == 0
        if no_adapter and self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        extra_key = (self.prompt_adapter_id, self.lora_int_id)
        return hash(extra_key)

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        size = self.block_size
        return logical_idx * size + size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record a newly generated token and its logprobs.

        `logprobs` must contain an entry for `token_id`; that entry's
        logprob is forwarded to the underlying sequence data.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen_logprob = logprobs[token_id].logprob
        self.data.append_token_id(token_id, chosen_logprob, token_embed)

    def get_len(self) -> int:
        """Total number of tokens, delegated to the sequence data."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Number of prompt tokens, delegated to the sequence data."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Number of output tokens, delegated to the sequence data."""
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        """All token ids, delegated to the sequence data."""
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Prompt token ids, delegated to the sequence data."""
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        """Last token id, delegated to the sequence data."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Output token ids, delegated to the sequence data."""
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        """Cumulative logprob, read from the sequence data."""
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        """Whether the current status is a finished status."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        in_decode = self.data.stage == SequenceStage.DECODE
        return 1 if in_decode else self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        """Number of computed tokens, delegated to the sequence data."""
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        """Whether the sequence data is in the prefill stage."""
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
688

Woosuk Kwon's avatar
Woosuk Kwon committed
689

690
691
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group"""

    # Counters for multi-step decoding.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps left to run: `num_steps` minus `current_step`."""
        total, done = self.num_steps, self.current_step
        return total - done


Woosuk Kwon's avatar
Woosuk Kwon committed
703
class SequenceGroup:
    """A set of sequences that were all generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must contain at least one element,
            because the first one is cached as ``first_seq``.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
            for single-draft speculative decoding but larger than
            that for multi-draft SD (currently not supported).
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        # Request identity and the sequences that belong to it.
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {member.seq_id: member for member in seqs}
        self.encoder_seq = encoder_seq

        # Per-request parameters handed in by the caller.
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.trace_headers = trace_headers
        self.priority = priority

        # Timing bookkeeping; the last-token time starts at arrival time.
        self.arrival_time = arrival_time
        acceptance_counts = [0] * draft_size
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
            spec_token_acceptance_counts=acceptance_counts)
        self.last_token_latency = 0.0

        # State that starts empty and is filled in later.
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """Prompt of the first sequence."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """Prompt token ids of the first sequence."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        """Prompt of the encoder sequence, or None when there is none."""
        # There is at most one encoder sequence; when present, its prompt
        # is distinct from the decoder's.
        encoder = self.encoder_seq
        if encoder is None:
            return None
        return encoder.prompt

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        """Prompt token ids of the encoder sequence, or None without one."""
        # There is at most one encoder sequence; when present, its prompt
        # token ids are distinct from the decoder's.
        encoder = self.encoder_seq
        if encoder is None:
            return None
        return encoder.prompt_token_ids

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        """Token type ids of the first sequence."""
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the first sequence, else of the encoder.

        Falls back to an empty ``MultiModalKwargs`` when the first sequence
        has no (truthy) multi-modal data and there is no encoder sequence.
        """
        decoder_mm_data = self.first_seq.multi_modal_data
        if decoder_mm_data:
            return decoder_mm_data
        if self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders of whichever sequence supplies the multi-modal data.

        The choice is driven by the first sequence's ``multi_modal_data``;
        an empty dict is returned when neither source applies.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        if self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        request = self.lora_request
        if not request:
            return 0
        return request.lora_int_id

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        request = self.prompt_adapter_request
        if not request:
            return 0
        return request.prompt_adapter_id

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter request, or 0."""
        request = self.prompt_adapter_request
        if not request:
            return 0
        return request.prompt_adapter_num_virtual_tokens

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the group state to ``num_steps`` steps, starting at step 0."""
        group_state = self.state
        group_state.num_steps = num_steps
        group_state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize the multi-step state from scheduler settings.

        Without multi-step, the state is simply set to
        ``num_scheduler_steps``. With multi-step, the step count is derived
        from ``num_lookahead_slots``; the asserts reflect the expectations
        of the current system.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        in_prefill = self.is_prefill()

        if in_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
            return

        if in_prefill:
            # A prefill must not request any lookahead slots.
            assert num_lookahead_slots == 0
        else:
            # For a decode, lookahead slots + 1 must match the scheduler
            # steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps
        self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # Calling this while still in the prefill phase is a usage error.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        metrics = self.metrics
        self.last_token_latency = now - metrics.last_token_time
        metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and recomputed,
        #   the time between iterations is counted in TPOT, rather than
        #   recalculating TTFT (since from the POV of the user, there is
        #   simply a long generation delay).
        if self.metrics.first_token_time is not None:
            return
        if self.first_seq.get_output_len() == 1:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        metrics = self.metrics
        if metrics.first_scheduled_time is not None:
            return
        metrics.first_scheduled_time = time
        metrics.time_in_queue = time - metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if not self.is_single_seq:
            return self.num_seqs() - self.num_finished_seqs()
        if self.first_seq.is_finished():
            return 0
        return 1

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the sequences, optionally only those with ``status``.

        With ``status=None`` the internal list itself is returned.
        """
        if status is None:
            return self.seqs

        if not self.is_single_seq:
            return [member for member in self.seqs if member.status == status]

        if self.first_seq.status == status:
            return self.seqs
        return []

    def is_encoder_decoder(self) -> bool:
        """Return True when an encoder sequence is attached."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None when there is none."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences whose ``is_finished()`` is true."""
        if not self.is_single_seq:
            return [member for member in self.seqs if member.is_finished()]

        if self.first_seq.is_finished():
            return self.seqs
        return []

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for member in self.seqs:
            if member.is_finished():
                continue
            member.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum the uncomputed-token counts of all unfinished sequences."""
        return sum(member.data.get_num_uncomputed_tokens()
                   for member in self.seqs if not member.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        # Optimization: no filtering (and no get_seqs call) is needed when
        # no status is requested.
        if status is None:
            return len(self.seqs)

        if not self.is_single_seq:
            return len(self.get_seqs(status))

        if self.seqs[0].status == status:
            return 1
        return 0

    def num_finished_seqs(self) -> int:
        """Count the sequences whose ``is_finished()`` is true."""
        if not self.is_single_seq:
            return len(self.get_finished_seqs())
        if self.seqs[0].is_finished():
            return 1
        return 0

    def is_finished(self) -> bool:
        """Return True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        for member in self.seqs:
            if not member.is_finished():
                return False
        return True

    def is_prefill(self) -> bool:
        """Return whether the first sequence is in the prefill stage."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        request_id = self.request_id
        sampling_params = self.sampling_params
        seq_count = len(self.seqs)
        return (f"SequenceGroup(request_id={request_id}, "
                f"sampling_params={sampling_params}, "
                f"num_seqs={seq_count})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        for member in self.seqs:
            if member.data.prompt_embeds is not None:
                return True
        return False

968

969
970
971
972
973
974
975
976
977
978
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update for a previously sent SequenceGroupMetadata.

    Once the first SequenceGroupMetadata has been sent, the vLLM scheduler
    only sends this delta to reduce the data payload size.
    """
    # Per-sequence deltas, keyed by sequence id.
    seq_data_delta: dict[int, SequenceDataDelta]
    # Id of the request this delta belongs to.
    request_id: str
    # Block tables, keyed by sequence id.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    # A fresh SequenceGroupState is created per instance by default.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Optional pooling parameters.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type ids.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Optional multi modal placeholders.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            When left as None it is filled in by `__post_init__`.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request (see the comment on the field).
    """

    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Fill in ``token_chunk_size`` when the caller left it unset.

        For a prompt it becomes the length of the first sequence data entry;
        otherwise it becomes 1.
        """
        if self.seq_data is None or self.token_chunk_size is not None:
            return
        if self.is_prompt:
            first_seq_data = next(iter(self.seq_data.values()))
            self.token_chunk_size = first_seq_data.get_len()
        else:
            self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter request, or 0."""
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
                        if self.prompt_adapter_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        """True when this is a prompt for which sampling is required."""
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        """Return the first key of ``seq_data``."""
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Apply a delta to this metadata in place.

        Each per-sequence delta is forwarded to the matching ``seq_data``
        entry, then ``block_tables``, ``token_chunk_size``, ``do_sample`` and
        ``is_prompt`` are overwritten with the delta's values.
        """
        delta = sequence_group_metadata_delta
        # Check the request id before touching anything, so a mismatched
        # delta cannot leave this object half-updated.
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance ``state.current_step`` by one.

        Asserts that a state exists and that not all steps are done yet.
        """
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, \
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
        self.state.current_step += 1

1105

1106
1107
1108
1109
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Model output produced for a single sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor; only its shape is
            shown in ``repr`` and it does not take part in ``==``.
    """
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        embed = self.output_embed
        output_embed_shape = None if embed is None else embed.shape
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        Raises:
            NotImplementedError: If ``other`` is not a SequenceOutput.
        """
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        if self.parent_seq_id != other.parent_seq_id:
            return False
        if self.output_token != other.output_token:
            return False
        return other.logprobs == self.logprobs
1139
1140


1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
class SequenceGroupOutput(ABC):
    """Abstract base for model outputs associated with a sequence group.

    Concrete subclasses must provide both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether ``other`` equals this output."""
        ...


1153
1154
1155
1156
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Model output for one completion sequence group."""
    __metaclass__ = SequenceGroupOutput
    # One SequenceOutput per sample.
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    # Not included in `repr` or `==`.
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        samples, prompt_logprobs = self.samples, self.prompt_logprobs
        return (f"CompletionSequenceGroupOutput(samples={samples}, "
                f"prompt_logprobs={prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        Raises:
            NotImplementedError: If ``other`` is not a
                CompletionSequenceGroupOutput.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        if not (self.samples == other.samples):
            return False
        return self.prompt_logprobs == other.prompt_logprobs

1174

1175
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def __repr__(self) -> str:
        """Return ``PoolingSequenceGroupOutput(data=...)``."""
        # The closing parenthesis was missing from this string before.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare the ``data`` of two pooling outputs.

        The result is whatever ``self.data == other.data`` evaluates to.

        Raises:
            NotImplementedError: If ``other`` is not a
                PoolingSequenceGroupOutput.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.data == other.data
1193
1194


1195
1196
1197
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Attributes:
        tensors: Mapping from tensor name to tensor.
    """

    tensors: dict[str, torch.Tensor]

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        A ``str`` key returns the stored tensor. A ``slice`` key returns a
        new instance of the same class whose tensors are each indexed with
        that slice.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        """Store ``value`` under ``key``."""
        self.tensors[key] = value

    def items(self):
        """Return the ``(name, tensor)`` items view of the mapping."""
        return self.tensors.items()

    def __len__(self):
        """Return the number of stored tensors."""
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True when ``other`` holds the same names and equal tensors.

        Tensors are compared with ``torch.equal``. Objects of a different
        class compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[name], other.tensors[name])
            for name in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1234
1235
1236
1237
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Result of a pooling operation in the pooling model.

    Indexing, item assignment and ``len`` all operate on ``outputs``.
    """
    outputs: list[PoolingSequenceGroupOutput]

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        """Return the output stored at position ``idx``."""
        entry = self.outputs[idx]
        return entry

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        """Replace the output stored at position ``idx``."""
        self.outputs[idx] = value

    def __len__(self):
        """Return the number of outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when ``other`` is of the same class with equal outputs."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1255
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Collect the sequence ids of every SequenceGroupMetadata in the list.

    Ids appear in list order, and within one group in the iteration order
    of its ``seq_data`` mapping.
    """
    all_seq_ids: list[int] = []
    for group_metadata in seq_group_metadata_list:
        # Iterating the mapping yields its keys, i.e. the sequence ids.
        all_seq_ids.extend(group_metadata.seq_data)
    return all_seq_ids


1263
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Returns:
        A pair ``(seq_ids, mapping)``: ``seq_ids`` lists every sequence id in
        encounter order, and ``mapping`` maps each request id to the set of
        its sequence ids. A request id only gets an entry once at least one
        of its sequence ids has been seen.
    """
    seq_ids: list[int] = []
    ids_by_request: defaultdict[str, set[int]] = defaultdict(set)
    for group_metadata in seq_group_metadata_list:
        group_seq_ids = list(group_metadata.seq_data)
        if not group_seq_ids:
            # Skip empty groups so no empty entry is created for them.
            continue
        seq_ids.extend(group_seq_ids)
        ids_by_request[group_metadata.request_id].update(group_seq_ids)
    return seq_ids, ids_by_request


1278
1279
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        """Derive ``_seq_ids`` from the metadata list when one is given."""
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """The tracked sequence ids (the internal list itself)."""
        return self._seq_ids

    def _first_index_by_seq_id(self) -> dict[int, int]:
        """Map each tracked seq id to the index of its first occurrence.

        Gives the same answers as ``self._seq_ids.index(seq_id)`` but is
        built in a single pass, so callers avoid repeated linear scans.
        """
        positions: dict[int, int] = {}
        for position, seq_id in enumerate(self._seq_ids):
            positions.setdefault(seq_id, position)
        return positions

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        position_of = self._first_index_by_seq_id()
        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [
            seq_id for seq_id in get_all_seq_ids(seq_group_metadata_list)
            if seq_id in position_of
        ]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [position_of[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        num_seq_ids = len(self._seq_ids)
        position_of = self._first_index_by_seq_id()
        index = []
        for seq_id in self._seq_ids:
            i = position_of[seq_id]
            if seq_id in seq_with_bonus_token_in_last_step:
                # Rows at offset `num_seq_ids` come from the second half of
                # the concatenation below (second_last_token_hidden_states).
                index.append(i + num_seq_ids)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1365

1366
1367
1368
1369
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    The struct is declared with `array_like=True`, so the declaration order
    of the fields below must not be changed.
    """
    # Metadata (or metadata deltas) of the sequence groups in this batch.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in, as (CPU block number, GPU block number) pairs.
    blocks_to_swap_in: list[tuple[int, int]] = msgspec.field(
        default_factory=list)
    # Blocks to swap out, as (GPU block number, CPU block number) pairs.
    blocks_to_swap_out: list[tuple[int, int]] = msgspec.field(
        default_factory=list)
    # Blocks to copy, as (source block, destination block) pairs.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from the prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # The step index for spec model input.
    spec_step_idx: Optional[int] = None
    # Ids of the requests that finished since the last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback.
    async_callback: Optional[Callable] = None

    @property
    def is_first_multi_step(self) -> bool:
        """Whether the first sequence group's state is at `current_step`
        0. Only the first sequence group of the batch is inspected."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        groups = self.seq_group_metadata_list
        assert len(groups) > 0
        step_state = groups[0].state
        assert step_state is not None
        return step_state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        """Whether the first sequence group's state has exactly one
        remaining step. Only the first sequence group is inspected."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        groups = self.seq_group_metadata_list
        assert len(groups) > 0
        step_state = groups[0].state
        assert step_state is not None
        return step_state.remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The `current_step` recorded in the first sequence group's
        state."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        groups = self.seq_group_metadata_list
        assert len(groups) > 0
        step_state = groups[0].state
        assert step_state is not None
        return step_state.current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The three block lists are shallow-copied and
        `last_sampled_token_ids` is cloned when present; the remaining
        forwarded fields are passed through by reference.
        """
        # NOTE(review): `spec_step_idx` is not forwarded, so the clone gets
        # the field default (None) - confirm this is intended.
        sampled_ids = self.last_sampled_token_ids
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=(None if sampled_ids is None else
                                    sampled_ids.clone()),
            async_callback=self.async_callback)
1448
1449
1450
1451
1452
1453
1454
1455
1456


@dataclass
class SequenceGroupBase:
    """Base bookkeeping for a request that is split into several
    sub-requests. Subclasses implement `add_request` and
    `maybe_assemble_group`."""
    # The original request id before splitting.
    group_id: str

    # The sequence group assembled from the sub-requests, once available.
    assembled_seq_group: Optional[SequenceGroup] = None

    # Maps each seq id to a unique index inside this group.
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # Sub-requests that have not finished yet, keyed by seq id.
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # Sub-requests that have finished, keyed by seq id.
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Not implemented on the base class.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that `seq` has finished: move its request id from
        `to_be_finished` to `finished_reqs`.

        Raises KeyError if the request id is not in `to_be_finished`.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Not implemented on the base class.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Splits one request asking for `n` samples into `n` single-sample
    requests and reassembles their sequences into one sequence group."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Register `params.n` single-sample requests with the engine.

        Each sub-request gets the id `{request_id}_parallel_sample_{i}`, a
        clone of `params` with `n` set to 1 and, when a seed is set, the
        seed offset by `i`. Each one is added through
        `engine._add_processed_request` and mapped to the new group in
        `engine.seq_id_to_seq_group`.

        Args:
            request_id: Id of the original (unsplit) request.
            engine: Engine exposing `_add_processed_request` and
                `seq_id_to_seq_group`.
            params: Sampling parameters of the original request.
            **kwargs: Forwarded unchanged to `_add_processed_request`.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                # Give every sample its own seed.
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        # The request-level attributes are taken from the last sub-request
        # created by the loop above.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = params.output_kind == RequestOutputKind.DELTA
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled sequence group when output is due for
        `seq_group`, else None.

        Streaming mode: the assembled group is returned only when
        `seq_group` is the first entry still in `to_be_finished`.

        Non-streaming mode: the assembled group is returned once, when
        `seq_group` is the only entry left in `to_be_finished` and it has
        finished. If `_real_n` is set on the sampling params, the group's
        sequences are first cut down to the top-n by cumulative logprob.
        """
        # in the streaming mode, we will return the assembled sequence
        # for the first remaining sequence, and then return None for the
        # rest of sequences
        if self.streaming:
            if not self.to_be_finished:
                # Nothing is pending; avoid leaking StopIteration from
                # next() on an empty iterator.
                return None
            first_remaining_id = next(iter(self.to_be_finished))
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        # in the non-streaming mode, we will return the assembled sequence
        # when the last sequences finishes, and then return None for the
        # rest of the time
        is_last_to_finish = (len(self.to_be_finished) == 1
                             and seq_group.request_id in self.to_be_finished
                             and seq_group.is_finished())
        if not is_last_to_finish:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Get the top-n sequences.
            n = params._real_n or params.n
            ranked_seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)
            self.assembled_seq_group.seqs = ranked_seqs[:n]
        return self.assembled_seq_group