sequence.py 58.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm.inputs import SingletonInputs
19
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
25
26
27
if TYPE_CHECKING:
    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorOutput)

28
# Typecode handed to `array.array` for every token-ID array built in this
# module ("l" is the C signed long typecode of the `array` module).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
29

30
31
# Sentinel token ID. NOTE(review): not referenced in the visible part of this
# file; presumably marks a missing/invalid token -- confirm against callers.
VLLM_INVALID_TOKEN_ID = -1

32

33
def array_full(token_id: int, count: int):
    """Build a token-ID [`array`][] holding `token_id` repeated `count` times.

    This is the [`array`][] counterpart of [numpy.full][].
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single * count


38
39
40
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
41
42
@dataclass
class Logprob:
    """Log-probability record for one chosen token.

    Used to support OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of chosen token
        rank: The vocab rank of chosen token (>=1)
        decoded_token: The decoded chosen token index
    """
    # Required: the log probability itself.
    logprob: float
    # Optional extras; both default to None when not supplied.
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


55
56
# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
57
PromptLogprobs = list[Optional[dict[int, Logprob]]]
58
# {token_id -> logprob} for each sequence group.
59
SampleLogprobs = list[dict[int, Logprob]]
60

Woosuk Kwon's avatar
Woosuk Kwon committed
61

62
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: every value greater than SWAPPED (2) is treated as a
    # finished status by `is_finished`.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` is ordered after `SWAPPED`."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its reason string.

        Returns:
            "stop", "length" or "abort" for the finished statuses, and
            None for any other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
94

95

96
97
98
99
100
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: PREFILL or DECODE."""
    # Set by SequenceData.reset_state_for_recompute and used as the default.
    PREFILL = enum.auto()
    # Set by SequenceData.update_num_computed_tokens once no tokens remain
    # uncomputed.
    DECODE = enum.auto()


101
102
103
104
@dataclass
class RequestMetrics:
    """Timing metrics collected for a single request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Timestamp tied to the most recent token
            (NOTE(review): undocumented in the original; confirm the exact
            meaning against callers).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    # Required fields (no defaults).
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    # Optional fields, all defaulting to None.
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
128
129


130
131
132
133
134
135
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Per-step increment of a SequenceData that is sent to workers."""
    # Output token IDs to append to the existing SequenceData.
    new_output_token_ids: list[int]
    # Value that replaces the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Value that replaces the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Value that replaces the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Prompt + output token IDs, kept in sync by the mutators below.
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    # Prompt + output embeddings concatenated along dim 0 (None without
    # prompt embeddings).
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if not token_counts:
            return SequenceData.from_seqs([])

        # Each `array_full` result is a fresh array, so folding with the
        # in-place add only mutates temporaries created here.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            # Let the field default supply an empty output array.
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        """Validate the array typecodes and build the derived caches."""
        assert self._prompt_token_ids.typecode == "l"
        assert self._output_token_ids.typecode == "l"
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild the prompt + output token ID list from both arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild the cached embeddings from prompt (+ output) embeds.

        Requires `_prompt_embeds` to be a tensor; output embeddings, when
        present, are concatenated after it along dim 0.
        """
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        """The running sum of logprobs added through `append_token_id`."""
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        """The prompt token IDs as the tuple built in `__post_init__`."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Replacing the prompt after construction is not supported.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses the `VLLM_TOKEN_ID_ARRAY_TYPE` typecode
        ("l", C signed long). Its item size is platform dependent and is
        not guaranteed to match torch.long, so beware of the usage.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        """The output token IDs, copied into a new tuple on every access."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        """The output embeddings, or None if none have been stored."""
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Store on the declared `_output_embeds` field so that the getter
        # and `_update_cached_all_token_embeds` both see the new value.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses the `VLLM_TOKEN_ID_ARRAY_TYPE` typecode
        ("l", C signed long). Its item size is platform dependent and is
        not guaranteed to match torch.long, so beware of the usage.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """The prompt embeddings, or None if the prompt has none."""
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        """The stored mrope position delta, if any."""
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one output token and update all derived state.

        Args:
            token_id: The token ID to append.
            logprob: Added to the cumulative logprob.
            token_embed: Optional 1-D embedding of the token. When given,
                the cached embeddings must already exist.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        """Return the prompt length plus the output length."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        """Return the cached prompt + output token ID list (not a copy)."""
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        """Return the cached prompt + output embeddings, if any."""
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            # The prefix reaches into the output: full prompt plus the
            # leading output tokens.
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Method form of the `prompt_token_ids` property."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Method form of the `output_token_ids` property."""
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the state changed since the last call and clear the
        list of newly appended tokens."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a `SequenceDataDelta`: overwrite the scalar state and
        extend the output and cached token IDs with the new tokens."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        """The current `SequenceStage` of this sequence."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
446
447


Woosuk Kwon's avatar
Woosuk Kwon committed
448
class Sequence:
    """Holds the data, status, and block information for one sequence.

    A sequence is built from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (encoder-decoder) instance given as the `inputs` constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        # Token IDs come from the `prompt_token_ids` property; embeddings
        # are forwarded only for "embeds"-type inputs.
        initial_token_ids = self.prompt_token_ids
        if self.inputs["type"] == "embeds":
            initial_embeds = self.inputs["prompt_embeds"]
        else:
            initial_embeds = None
        self.data = SequenceData.from_seqs(initial_token_ids,
                                           prompt_embeds=initial_embeds)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets remembering what has already been returned as delta output
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Sequence length divided by the block size, rounded up."""
        total_len = self.get_len()
        return (total_len + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        """The prompt text from the inputs; None for "embeds" inputs."""
        return (None if self.inputs["type"] == "embeds" else
                self.inputs.get("prompt"))

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token IDs from the inputs.

        For "embeds" inputs, a list of zeros with one entry per prompt
        embedding is returned instead.
        """
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def token_type_ids(self) -> list[int]:
        """The token type IDs from the inputs (empty list when absent or
        for "embeds" inputs)."""
        if self.inputs["type"] != "embeds":
            return self.inputs.get("token_type_ids", [])
        return []

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """The `mm_kwargs` of "multimodal" inputs, else empty kwargs."""
        if self.inputs["type"] != "multimodal":
            return MultiModalKwargs({})
        return self.inputs["mm_kwargs"]

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """The `mm_placeholders` of "multimodal" inputs, else an empty
        dict."""
        if self.inputs["type"] != "multimodal":
            return {}
        return self.inputs["mm_placeholders"]

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        request = self.lora_request
        if not request:
            return 0
        return request.lora_int_id

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text, optionally holding back a tail buffer.

        If delta is True, only new text since the last call to
        this method is returned"""
        text = self.output_text

        # The full output text is returned once the sequence is finished.
        hold_back = bool(buffer_length) and not self.is_finished()

        if not delta:
            if hold_back:
                return text[:-buffer_length]
            return text

        end = len(text) - buffer_length if hold_back else len(text)
        start = self._last_output_text_offset
        if start >= end:
            # Nothing new to emit yet.
            return ""
        self._last_output_text_offset = end
        return text[start:end]

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return the output token IDs.

        If delta is True, only new tokens since the last call to
        this method are returned"""
        if not delta:
            return self.get_output_token_ids()

        current_len = self.get_output_len()

        # Count what was appended since the previous call, then advance
        # the offset.
        fresh_count = current_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current_len

        if fresh_count == 0:
            return []

        all_token_ids = self.data._cached_all_token_ids
        if fresh_count == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return all_token_ids[-1]

        return all_token_ids[-fresh_count:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx` together
        with the LoRA integer ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix_len = self.num_hashed_tokens_of_block(logical_idx)
        prefix_tokens = self.data.get_prefix_token_ids(prefix_len)
        return hash((prefix_tokens, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        lora_id = self.lora_int_id
        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return None if lora_id == 0 else hash(lora_id)

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered up to and including block
        `logical_idx`."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record `logprobs` and append `token_id` to the sequence data.

        `logprobs` must contain an entry for `token_id`.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen_logprob = logprobs[token_id].logprob
        self.data.append_token_id(token_id, chosen_logprob, token_embed)

    # The accessors below delegate to `self.data`.

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        """Whether the current status counts as finished."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence and give the copy `new_seq_id`."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage != SequenceStage.DECODE:
            return self.data.get_num_uncomputed_tokens()
        return 1

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        """Whether the sequence data is in the PREFILL stage."""
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
675

Woosuk Kwon's avatar
Woosuk Kwon committed
676

677
678
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group"""

    # Step counters used for multi-step decoding.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps still to run: `num_steps` minus `current_step`."""
        total, done = self.num_steps, self.current_step
        return total - done


Woosuk Kwon's avatar
Woosuk Kwon committed
690
class SequenceGroup:
    """Bundle of sequences that all originate from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The sequences of the group. Must be non-empty because the
            first entry is cached as `first_seq`.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
            for single-draft speculative decoding but larger than
            that for multi-draft SD (currently not supported).
            NOTE(review): the constructor accepts this argument but does not
            store it - confirm whether that is intentional.
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        # Request identity and the sequences that belong to it.
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        # Configuration handed in by the caller.
        self.sampling_params = sampling_params
        self.lora_request = lora_request
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        # Timing bookkeeping; the last-token time starts out at arrival.
        self.arrival_time = arrival_time
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0

        # Mutable per-group state, filled in later.
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """Prompt of the first sequence."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """Prompt token ids of the first sequence."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        """Prompt of the encoder sequence, or None without an encoder."""
        # There are either 0 or 1 encoder sequences; when one is present its
        # prompt is distinct from the decoder's.
        if self.encoder_seq is None:
            return None
        return self.encoder_seq.prompt

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        """Prompt token ids of the encoder sequence, or None without one."""
        # There are either 0 or 1 encoder sequences; when one is present its
        # prompt token ids are distinct from the decoder's.
        if self.encoder_seq is None:
            return None
        return self.encoder_seq.prompt_token_ids

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        """Token type ids of the first sequence."""
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the group.

        The first sequence's data wins when it is truthy; otherwise the
        encoder sequence's data is used when an encoder sequence exists;
        otherwise an empty `MultiModalKwargs` is returned.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        if self.encoder_seq is None:
            return MultiModalKwargs({})
        return self.encoder_seq.multi_modal_data

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multi-modal placeholders of the group.

        Selection mirrors `multi_modal_data`: the decision is made on the
        first sequence's `multi_modal_data`, not on its placeholders.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        if self.encoder_seq is None:
            return {}
        return self.encoder_seq.multi_modal_placeholders

    @property
    def lora_int_id(self) -> int:
        """`lora_int_id` of the LoRA request, or 0 when there is none."""
        if not self.lora_request:
            return 0
        return self.lora_request.lora_int_id

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step state to `num_steps` steps, starting at 0."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize the multi-step state from scheduler lookahead settings.

        Args:
            num_lookahead_slots: Number of lookahead slots scheduled.
            num_scheduler_steps: Number of scheduler steps configured.
            is_multi_step: When False, `num_scheduler_steps` is used as-is.
            enable_chunking: Whether chunking is enabled; only matters for
                prefill groups in the multi-step case.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-step case. The asserts below reflect the expectations of the
        # current system.
        prefill = self.is_prefill()

        if prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
            return

        decode: bool = not prefill
        # If it is a prefill, num_lookahead_slots must be 0.
        assert num_lookahead_slots == 0 or decode
        # If it is a decode, num_lookahead_slots + 1 must match the
        # scheduler steps.
        assert num_lookahead_slots + 1 == num_scheduler_steps or prefill
        self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # Must not be called while the group is still in the prefill phase.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        previous = self.metrics.last_token_time
        self.last_token_latency = now - previous
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and recomputed,
        # the time between iterations is counted in TPOT rather than
        # recalculating TTFT (since, from the POV of the user, there is
        # simply a long generation delay).
        if self.metrics.first_token_time is not None:
            return
        if self.first_seq.get_output_len() == 1:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is not None:
            return
        self.metrics.first_scheduled_time = time
        self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if not self.is_single_seq:
            return self.num_seqs() - self.num_finished_seqs()
        return 0 if self.first_seq.is_finished() else 1

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the group's sequences, optionally filtered by `status`.

        With `status=None` the internal list itself is returned (no copy).
        """
        if status is None:
            return self.seqs

        if not self.is_single_seq:
            return [seq for seq in self.seqs if seq.status == status]

        if self.first_seq.status == status:
            return self.seqs
        return []

    def is_encoder_decoder(self) -> bool:
        """Return True when the group carries an encoder sequence."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence (None when there is none)."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences for which `is_finished()` is true."""
        if not self.is_single_seq:
            return [seq for seq in self.seqs if seq.is_finished()]

        if self.first_seq.is_finished():
            return self.seqs
        return []

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if seq.is_finished():
                continue
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum the uncomputed token counts of all unfinished sequences."""
        return sum(seq.data.get_num_uncomputed_tokens() for seq in self.seqs
                   if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with `status`."""
        # Optimization: no need to call get_seqs when not filtering.
        if status is None:
            return len(self.seqs)

        if not self.is_single_seq:
            return len(self.get_seqs(status))

        return 1 if self.seqs[0].status == status else 0

    def num_finished_seqs(self) -> int:
        """Count the sequences for which `is_finished()` is true."""
        if not self.is_single_seq:
            return len(self.get_finished_seqs())
        return 1 if self.seqs[0].is_finished() else 0

    def is_finished(self) -> bool:
        """Return True when every sequence of the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        for seq in self.seqs:
            if not seq.is_finished():
                return False
        return True

    def is_prefill(self) -> bool:
        """Return whether the first sequence is in prefill."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        parts = [
            f"SequenceGroup(request_id={self.request_id}, ",
            f"sampling_params={self.sampling_params}, ",
            f"num_seqs={len(self.seqs)})",
        ]
        return "".join(parts)

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        for seq in self.seqs:
            if seq.data.prompt_embeds is not None:
                return True
        return False

940

941
942
943
944
945
946
947
948
949
950
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update for a previously sent SequenceGroupMetadata.

    Once the first SequenceGroupMetadata has been sent, the vLLM scheduler
    only sends this delta to reduce the data payload size.
    """
    # NOTE: the struct is `array_like`, so the field order below is part of
    # the serialized layout and must not be changed.

    # Seq id -> delta of that sequence's data.
    seq_data_delta: dict[int, SequenceDataDelta]
    # ID of the request the delta belongs to.
    request_id: str
    # Seq id -> list of block numbers.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    # Every instance gets its own fresh default state object.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters, if any.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Token type ids, if any.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholders, if any.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        token_chunk_size: The number of tokens to be processed (per sequence).
            When left as None it is derived in `__post_init__`.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request (see the comment on the field).
    """
    # NOTE: the struct is `array_like`, so the field order below is part of
    # the serialized layout and must not be changed.

    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    # Every instance gets its own fresh default state object.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Derive `token_chunk_size` when the caller did not provide one.

        For a prompt the length of the first sequence's data is used; for
        anything else the chunk size is 1.
        """
        if self.seq_data is None or self.token_chunk_size is not None:
            return
        if not self.is_prompt:
            self.token_chunk_size = 1
            return
        first_seq_data = next(iter(self.seq_data.values()))
        self.token_chunk_size = first_seq_data.get_len()

    @property
    def lora_int_id(self) -> int:
        """`lora_int_id` of the LoRA request, or 0 when there is none."""
        if not self.lora_request:
            return 0
        return self.lora_request.lora_int_id

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        """Whether this is a prompt that also samples in this step."""
        # do_sample is true only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence, which indicates that the
        # prompt will finish processing in a single `execute_model` step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        """Return the first key of `seq_data`."""
        # This is an efficient way of fetching the seq_id when we know this
        # SequenceGroup has only one sequence.
        for seq_id in self.seq_data:
            return seq_id
        raise StopIteration

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Fold a delta into this metadata in place.

        Per-sequence deltas are applied first; afterwards the request ids
        are asserted to match and the scalar fields are overwritten.
        """
        delta = sequence_group_metadata_delta
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        assert self.request_id == delta.request_id
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance `state.current_step` by one, asserting steps remain."""
        state = self.state
        assert state is not None
        assert state.current_step < state.num_steps, \
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
        state.current_step += 1

1065

1066
1067
1068
1069
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional embedding tensor of the output; only its
            shape is shown in the repr and it is ignored by `__eq__`.
    """
    # NOTE: the struct is `array_like`, so field order is part of the layout.
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        if self.output_embed is None:
            output_embed_shape = None
        else:
            output_embed_shape = self.output_embed.shape
        parts = [
            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, ",
            f"output_token={self.output_token}, ",
            f"output_embed.shape={output_embed_shape}, ",
            f"logprobs={self.logprobs})",
        ]
        return "".join(parts)

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        Raises:
            NotImplementedError: If `other` is not a SequenceOutput.
        """
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        same_ids = (self.parent_seq_id == other.parent_seq_id
                    and self.output_token == other.output_token)
        same_logprobs = other.logprobs == self.logprobs
        return same_ids and same_logprobs
1099
1100


1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
class SequenceGroupOutput(ABC):
    """Abstract base for model outputs associated with a sequence group.

    Concrete subclasses have to supply both `__repr__` and `__eq__`.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a printable representation of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether `other` is equal to this output."""


1113
1114
1115
1116
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    __metaclass__ = SequenceGroupOutput
    # One entry per sampled sequence output.
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    # Not part of the repr or of equality.
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        parts = [
            f"CompletionSequenceGroupOutput(samples={self.samples}, ",
            f"prompt_logprobs={self.prompt_logprobs})",
        ]
        return "".join(parts)

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        Raises:
            NotImplementedError: If `other` is of a different type.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        if self.samples != other.samples:
            return False
        return self.prompt_logprobs == other.prompt_logprobs

1134

1135
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE: a `__metaclass__` class attribute has no effect in Python 3; it
    # is kept only as a plain attribute.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Return `nbytes` of the pooled data (treated as a tensor)."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        # The closing parenthesis was previously missing, which produced an
        # unbalanced repr string.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare the pooled data with `==`.

        NOTE(review): if `data` is a multi-element tensor, `==` yields an
        element-wise tensor rather than a bool - confirm callers expect it.

        Raises:
            NotImplementedError: If `other` is of a different type.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.data == other.data
1157
1158


1159
1160
1161
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own kv_connector_output.

    Args:
        tensors: Mapping from tensor name to tensor.
        kv_connector_output: Optional KV connector output of the stage;
            defaults to None.
    """

    tensors: dict[str, torch.Tensor]
    kv_connector_output: Optional["KVConnectorOutput"]

    def __init__(self,
                 tensors,
                 kv_connector_output: Optional["KVConnectorOutput"] = None):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors
        # Always assign the declared field so that reading it never raises
        # AttributeError.
        self.kv_connector_output = kv_connector_output

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        A `str` key returns the tensor stored under that name. A `slice`
        key returns a new instance whose tensors are each indexed with the
        slice (the new instance does not carry `kv_connector_output`).
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        """Return the `(name, tensor)` items view of `tensors`."""
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True when `other` holds the same names and equal tensors.

        Previously this returned `self` instead of a comparison result.
        Tensors are compared with `torch.equal` because `==` on tensors is
        element-wise and has no single truth value.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1201
1202
1203
1204
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model.

    Behaves like a small sequence over `outputs` (indexing, assignment,
    `len`).
    """
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Total `get_data_nbytes()` over all contained outputs."""
        total = 0
        for output in self.outputs:
            total += output.get_data_nbytes()
        return total

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when `other` has this instance's class and equal outputs."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1225
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Flatten the sequence ids of all given sequence groups into one list.

    Ids appear in group order and, within a group, in the iteration order
    of its `seq_data` mapping.
    """
    all_ids: list[int] = []
    for group in seq_group_metadata_list:
        all_ids.extend(group.seq_data)
    return all_ids


1233
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Returns:
        A pair `(seq_ids, mapping)`: `seq_ids` lists every sequence id in
        group order, and `mapping` maps each request id to the set of its
        sequence ids. Groups without sequences add no entry to `mapping`.
    """
    all_ids: list[int] = []
    ids_by_request: defaultdict[str, set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        group_ids = list(group.seq_data)
        if not group_ids:
            # Do not create an empty entry for a group without sequences.
            continue
        all_ids.extend(group_ids)
        ids_by_request[group.request_id].update(group_ids)
    return all_ids, ids_by_request


1248
1249
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted
    # tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        """Derive `_seq_ids` from the metadata list when one is given."""
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """Sequence ids, one per entry of the batch dimension."""
        return self._seq_ids

    def _seq_id_positions(self) -> dict[int, int]:
        """Map each tracked sequence id to its first index in `_seq_ids`.

        `setdefault` keeps the first occurrence, matching `list.index`.
        """
        positions: dict[int, int] = {}
        for position, seq_id in enumerate(self._seq_ids):
            positions.setdefault(seq_id, position)
        return positions

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        positions = self._seq_id_positions()
        # Only keep sequence IDs that exist in self._seq_ids. The dict gives
        # O(1) membership tests and index lookups instead of list scans.
        seq_ids = [
            seq_id for seq_id in get_all_seq_ids(seq_group_metadata_list)
            if seq_id in positions
        ]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [positions[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        positions = self._seq_id_positions()
        # Rows of `second_last_token_hidden_states` sit after the rows of
        # `hidden_states` in the concatenation below, hence the offset.
        offset = len(self._seq_ids)
        index = []
        for seq_id in self._seq_ids:
            i = positions[seq_id]
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + offset)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1335

1336
1337
1338
1339
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: the struct is `array_like`, so the field order below is part of
    # the serialized layout and must not be changed.

    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_group_state(self) -> SequenceGroupState:
        """Return the state of the first sequence group in the batch.

        Asserts that the batch is non-empty and that the state is set.
        """
        # TODO(will) make this be able to handle batches with variable number
        # of steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first group's `current_step` is 0."""
        return self._first_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first group has exactly one remaining step."""
        return self._first_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """`current_step` of the first group's state."""
        return self._first_group_state().current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The three block lists are shallow-copied and the last sampled token
        ids tensor (if any) is cloned; the remaining fields are passed on
        as-is.
        """
        sampled = self.last_sampled_token_ids
        sampled_copy = sampled.clone() if sampled is not None else None
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=sampled_copy,
            async_callback=self.async_callback)
1416
1417
1418
1419
1420
1421
1422
1423
1424


@dataclass
class SequenceGroupBase:
    """Bookkeeping for one original request that was split into several.

    Subclasses implement `add_request` (how the request is split) and
    `maybe_assemble_group` (when the combined group is handed back).
    """

    group_id: str  # the original request id before splitting

    # the combined sequence group for the original request, once available
    assembled_seq_group: Optional[SequenceGroup] = None

    # seq id to a unique index inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # seq ids to be finished
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # seq id to finished sequences
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """The sequence `seq` finishes, we should record the information.

        Moves `seq` from `to_be_finished` to `finished_reqs`, keyed by
        `seq.request_id`.

        Raises:
            KeyError: If `seq.request_id` is not in `to_be_finished`.
        """
        del self.to_be_finished[seq.request_id]
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split a request with `params.n` samples into single-sample requests.

        For each index `i` in `range(params.n)` a clone of `params` with
        `n = 1` is submitted through `engine._add_processed_request` under
        the id `f"{request_id}_parallel_sample_{i}"`. Each sub-request is
        registered in `engine.seq_id_to_seq_group` and in the new group's
        `to_be_finished` map, and its first sequence is collected into the
        group's `assembled_seq_group`.

        Args:
            request_id: Id of the original request.
            engine: Object providing `_add_processed_request` and the
                `seq_id_to_seq_group` mapping.
            params: Parameters of the original request; it is cloned per
                sample and is not itself modified by this method.
            **kwargs: Forwarded unchanged to `_add_processed_request`.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                # Shift the seed by the sample index (presumably so seeded
                # samples do not all produce the same output).
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        # Request-level fields are read from the last sub-request created
        # by the loop above.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            priority=seq_group.priority,
        )

        group.streaming = params.output_kind == RequestOutputKind.DELTA
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when `seq_group` should trigger output.

        Args:
            seq_group: One of the sub-requests belonging to this group.

        Returns:
            `self.assembled_seq_group` or None, as follows:

            * Streaming: the assembled group if `seq_group` is the first
              entry of `to_be_finished`, otherwise None.
            * Non-streaming: the assembled group exactly once, when
              `seq_group` is the only entry left in `to_be_finished` and is
              finished; None on every other call. If `_real_n` is set on
              the sampling params, the group's `seqs` are first reduced to
              the sequences with the highest cumulative logprob.
        """
        # in the streaming mode, we will return the assembled sequence
        # for the first remaining sequence, and then return None for the
        # rest of sequences
        if self.streaming:
            first_remaining_id = next(iter(self.to_be_finished))
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        # in the non-streaming mode, we will return the assembled sequence
        # when the last sequences finishes, and then return None for the
        # rest of the time
        if not (len(self.to_be_finished) == 1
                and seq_group.request_id in self.to_be_finished
                and seq_group.is_finished()):
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)

        # The final output is handed out only once.
        if self.output_produced:
            return None
        self.output_produced = True

        if params._real_n is not None:
            # Get the top-n sequences.
            n = params._real_n or params.n
            ranked_seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)
            self.assembled_seq_group.seqs = ranked_seqs[:n]
        return self.assembled_seq_group