sequence.py 56.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm.inputs import SingletonInputs
19
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
20
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.lora.request import LoRARequest
26
27
    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorOutput)
28
29
30
else:
    LoRARequest = Any
    KVConnectorOutput = Any
31

32
# Typecode handed to `array.array` for every token-id buffer built in this
# module ("l" is the stdlib typecode for a C signed long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
33

34
35
# Token id value reserved to mean "invalid". NOTE(review): no use of it is
# visible in this part of the file -- confirm against callers.
VLLM_INVALID_TOKEN_ID = -1

36

37
def array_full(token_id: int, count: int):
    """Build a token-id [`array`][] holding `token_id` repeated `count` times.

    This is the [`array`][] counterpart of [numpy.full][].
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, (token_id, ))
    return single * count


42
class SequenceStatus(enum.IntEnum):
    """Lifecycle states a sequence can be in."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # NOTE: every member numbered above SWAPPED (2) is treated as a finished
    # status; `is_finished` relies on this ordering.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` compares greater than `SWAPPED`."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the FINISHED_* statuses, and
            None for any other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        # Ignored sequences are those whose prompt is longer than the model's
        # length cap, so they report "length" as well (as in the OpenAI API).
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
74

75

76
77
78
79
80
class SequenceStage(enum.Enum):
    """Processing phase tracked for a sequence's data."""
    # Some tokens of the sequence have not been run through the model yet;
    # also the phase a sequence returns to when it is reset for recompute.
    PREFILL = enum.auto()
    # All tokens so far are computed; each step computes one new token.
    DECODE = enum.auto()


81
82
83
84
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Presumably the time when the most recent token was
                         generated (TODO confirm against the code that
                         updates it; not visible here).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
108
109


110
111
112
113
114
115
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Built by `SequenceData.get_delta_and_reset` and consumed by
    `SequenceData.apply_delta`.
    """
    # Output token ids appended since the previous delta was taken; the
    # receiver extends its existing output tokens with them.
    new_output_token_ids: list[int]
    # Overwriting existing `cumulative_logprob`
    new_cumulative_logprob: float
    # Overwriting existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwriting existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Stores the prompt and output token ids (plus optional prompt/output
    embeddings), the cumulative log probability of the output, and the
    bookkeeping that tracks how many tokens have been computed. Prefer the
    `from_seqs` / `from_prompt_token_counts` constructors; the fields below
    the "should not be passed" marker are derived state that
    `__post_init__` and the mutators keep in sync.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`. With no tuples, an empty prompt is produced.
        """
        if not token_counts:
            return SequenceData.from_seqs([])

        # Every `array_full` call yields a fresh array, so folding with the
        # in-place `__iadd__` only mutates the array built for the first pair.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.

        Both sequences are copied into token-id arrays; when
        `output_token_ids` is None the output starts out empty.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        """Validate the token arrays and build the derived caches."""
        # Use the shared constant rather than a literal so the check cannot
        # drift from the typecode the constructors actually use.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt+output token id list from the arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild the cached embeddings: prompt embeds followed by output
        embeds (when present), concatenated along dim 0."""
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        """The cumulative log probability of the output."""
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        """The token IDs of the prompt."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # The prompt cannot be replaced after construction.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses the "l" typecode
        (`VLLM_TOKEN_ID_ARRAY_TYPE`), a C signed long whose item size is
        platform dependent, so it is not guaranteed to match the 8-byte
        torch.long. Beware when reinterpreting the raw buffer.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        """The token IDs of the output."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        """The embeddings of the output tokens, if any."""
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Assign the declared `_output_embeds` field. The previous code wrote
        # to `_output_token_embeds`, which is not a field of this struct, so
        # the getter and the cached embeddings never saw the new value.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses the "l" typecode
        (`VLLM_TOKEN_ID_ARRAY_TYPE`), a C signed long whose item size is
        platform dependent, so it is not guaranteed to match the 8-byte
        torch.long. Beware when reinterpreting the raw buffer.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """The embeddings of the prompt, if any."""
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        """Value used to compute mrope_position_ids."""
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token and keep all caches in sync.

        Args:
            token_id: The new output token id.
            logprob: Its log probability, added to the cumulative logprob.
            token_embed: Optional 1-D embedding of the token. It is detached,
                moved to CPU and appended to the output embeddings and to the
                cached all-token embeddings.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        """Return the cached list of all token ids (prompt then output).

        The internal list itself is returned, not a copy.
        """
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        """Return the cached embeddings of all tokens, if any."""
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns a `(prompt_part, output_part)` pair; `output_part` is None
        when the first `num_tokens` tokens all lie inside the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens that are not computed yet."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token id, or the last prompt token id when
        no output has been generated yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Method form of the `prompt_token_ids` property."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Method form of the `output_token_ids` property."""
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the tokens appended since the last call, together with the
        current logprob / computed-token count / stage, and start a new
        pending-token list."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset`.

        Only token ids and scalar state are updated; the embedding fields
        are left untouched.
        """
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        """The current processing stage (prefill or decode)."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
418
419


Woosuk Kwon's avatar
Woosuk Kwon committed
420
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance passed in through the `inputs`
    constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        # Prompt embeddings are only forwarded for "embeds"-type inputs.
        token_ids = self.prompt_token_ids
        if self.inputs["type"] == "embeds":
            embeds = self.inputs["prompt_embeds"]
        else:
            embeds = None
        self.data = SequenceData.from_seqs(token_ids, prompt_embeds=embeds)

        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets remembering how much output was already handed out, so that
        # delta outputs can be produced.
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # State for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed for all tokens (ceiling division)."""
        total_tokens = self.get_len()
        return (total_tokens + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        """The prompt text, or None for "embeds"-type inputs."""
        return (None if self.inputs["type"] == "embeds" else
                self.inputs.get("prompt"))

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids.

        For "embeds"-type inputs a list of zeros, one per prompt embedding,
        is returned instead.
        """
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal kwargs of the inputs; empty unless they are
        "multimodal"-type."""
        if self.inputs["type"] != "multimodal":
            return MultiModalKwargs()

        return self.inputs["mm_kwargs"].get_data()

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multi-modal placeholders of the inputs; empty unless they are
        "multimodal"-type."""
        if self.inputs["type"] != "multimodal":
            return {}

        return self.inputs["mm_placeholders"]

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        if not self.lora_request:
            return 0
        return self.lora_request.lora_int_id

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text that may be handed out now.

        While the sequence is unfinished, the trailing `buffer_length`
        characters are held back. If delta is True, only new text since the
        last call to this method is returned.
        """
        # A finished sequence always exposes its full output text.
        hold_back = bool(buffer_length) and not self.is_finished()

        if delta:
            visible_len = len(self.output_text)
            if hold_back:
                visible_len -= buffer_length

            start = self._last_output_text_offset
            if start >= visible_len:
                return ""
            self._last_output_text_offset = visible_len
            return self.output_text[start:visible_len]

        if hold_back:
            return self.output_text[:-buffer_length]
        return self.output_text

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return the output token ids to hand out.

        If delta is True, only new tokens since the last call to this method
        are returned; a single new token is returned as a bare int.
        """
        if not delta:
            return self.get_output_token_ids()

        current_len = self.get_output_len()

        # Work out how many tokens are new, then advance the offset.
        fresh = current_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current_len

        if fresh == 0:
            return []

        all_ids = self.data._cached_all_token_ids
        if fresh == 1:
            # Fast path for the single decode token case
            # (which is what we have most of the time)
            return all_ids[-1]

        return all_ids[-fresh:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx`, combined with
        the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        prefix_len = self.num_hashed_tokens_of_block(logical_idx)
        prefix = self.data.get_prefix_token_ids(prefix_len)
        return hash((prefix, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.

        Returns None when no LoRA is in use (LoRA id 0).
        """
        lora_id = self.lora_int_id
        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash(lora_id) if lora_id != 0 else None

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record a newly generated token.

        `logprobs` must contain an entry for `token_id`; the whole dict is
        stored and that entry's logprob is forwarded to the sequence data.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen_logprob = logprobs[token_id].logprob
        self.data.append_token_id(token_id, chosen_logprob, token_embed)

    def get_len(self) -> int:
        """Total number of tokens (prompt + output)."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Number of prompt tokens."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Number of output tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        """All token ids (prompt then output)."""
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        """Prompt token ids as stored in the sequence data."""
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        """Id of the most recent token."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        """Output token ids as a tuple."""
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        """Cumulative log probability of the output."""
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        """Whether the current status is a finished one."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        child = copy.deepcopy(self)
        child.seq_id = new_seq_id
        return child

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage != SequenceStage.DECODE:
            return self.data.get_num_uncomputed_tokens()
        return 1

    def get_num_computed_tokens(self) -> int:
        """Number of tokens already computed."""
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        """Whether the sequence data is in the prefill stage."""
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
641

Woosuk Kwon's avatar
Woosuk Kwon committed
642

643
644
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable per-sequence-group state."""

    # Multi-step decoding progress: total steps and the step reached so far.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of steps still to run (`num_steps - current_step`)."""
        steps_done = self.current_step
        return self.num_steps - steps_done


Woosuk Kwon's avatar
Woosuk Kwon committed
656
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
                    model; equal to max number of tokens a step can generate
                    for single-draft speculative decoding but larger than
                    that for multi-draft SD (currently not supported).
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Cached so the single-sequence fast paths below avoid list scans.
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority
        # Keep the documented constructor argument instead of silently
        # dropping it, so it can be inspected on the group afterwards.
        self.draft_size = draft_size

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """The prompt of the first sequence in the group."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids of the first sequence in the group."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        """The encoder sequence's prompt, or None without an encoder."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        """The encoder sequence's prompt token ids, or None without one."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the first sequence if it is truthy, else of
        the encoder sequence if present, else an empty `MultiModalKwargs`."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs()

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multi-modal placeholders, selected with the same rule as
        `multi_modal_data` (keyed on the first sequence's multi-modal data);
        an empty dict when neither source applies."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """The LoRA request's integer id, or 0 when there is no request."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, assertion fails.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the sequences, optionally filtered by exact status.

        With ``status=None`` (and in the single-sequence match case) the
        group's own ``seqs`` list object is returned, not a copy.
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Return True when an encoder sequence is attached."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None if there is none."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences for which ``is_finished()`` is true."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum the uncomputed-token counts of all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Count the sequences for which ``is_finished()`` is true."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """Return True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Return whether the first sequence is in the prefill stage."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)

873

874
875
876
877
878
879
880
881
882
883
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
    """
    # NOTE: the struct is array_like, so the field order below is part of
    # the encoded layout and must not be rearranged.
    # Seq id -> delta to apply to that sequence's data.
    seq_data_delta: dict[int, SequenceDataDelta]
    request_id: str
    # Seq id -> list of block numbers.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    # Each instance gets its own freshly constructed state object.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Attributes:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholders.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        token_chunk_size: Number of tokens handled in this iteration. When
            left as None it is filled in by `__post_init__`.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request (see the comment on the field).
    """

    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        """Fill in `token_chunk_size` when the caller did not provide it.

        For a prompt the length of the first sequence's data is used; for a
        non-prompt step the chunk size is 1.
        """
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA request's integer id, or 0 when there is no request."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Apply a delta in place.

        Each per-sequence delta is forwarded to the matching `seq_data`
        entry, then `block_tables`, `token_chunk_size`, `do_sample` and
        `is_prompt` are overwritten with the delta's values.
        """
        # Validate before touching any state so a delta addressed to a
        # different request cannot leave this object partially updated.
        assert self.request_id == sequence_group_metadata_delta.request_id
        for seq_id, delta in (
                sequence_group_metadata_delta.seq_data_delta.items()):
            self.seq_data[seq_id].apply_delta(delta)
        self.block_tables = sequence_group_metadata_delta.block_tables
        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
        self.do_sample = sequence_group_metadata_delta.do_sample
        self.is_prompt = sequence_group_metadata_delta.is_prompt

    def finish_step(self) -> None:
        """Advance `state.current_step` by one.

        Asserts that a state object exists and that the current step is
        still below `state.num_steps`.
        """
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, \
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
        self.state.current_step += 1

997

998
999
1000
1001
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Attributes:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor.
    """
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        # Only the shape of the embedding is shown, never its contents.
        if self.output_embed is not None:
            output_embed_shape = self.output_embed.shape
        else:
            output_embed_shape = None
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        `output_embed` does not take part in the comparison. Comparing
        against anything that is not a `SequenceOutput` raises
        `NotImplementedError`.
        """
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        if not self.parent_seq_id == other.parent_seq_id:
            return False
        if not self.output_token == other.output_token:
            return False
        return other.logprobs == self.logprobs
1032
1033


1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation; concrete outputs override this."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare against ``other``; concrete outputs override this."""


1046
1047
1048
1049
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        # `step_index` is intentionally not part of the representation.
        text = (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")
        return text

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs (`step_index` is ignored).

        Comparing against any other type raises `NotImplementedError`.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        same_samples = self.samples == other.samples
        return (same_samples
                and self.prompt_logprobs == other.prompt_logprobs)

1067

1068
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Return ``data.nbytes``, treating `data` as a tensor."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        # The closing parenthesis was previously missing from this string.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Return whether both outputs carry equal data.

        Comparing against any other type raises `NotImplementedError`.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        # `tensor == tensor` is elementwise and yields a tensor, which is not
        # a valid `bool` result and breaks list comparisons of these objects.
        # Collapse tensor comparisons to a single truth value instead.
        if (isinstance(self.data, torch.Tensor)
                and isinstance(other.data, torch.Tensor)):
            return torch.equal(self.data, other.data)
        return self.data == other.data
1090
1091


1092
1093
1094
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own kv_connector_output.

    Attributes:
        tensors: Mapping from tensor name to tensor.
        kv_connector_output: Optional KV connector output for this stage;
            None when not provided.
    """

    tensors: dict[str, torch.Tensor]
    kv_connector_output: Optional["KVConnectorOutput"]

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        kv_connector_output: Optional["KVConnectorOutput"] = None,
    ) -> None:
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors
        # Always assign the declared field so reading it never raises
        # AttributeError on instances built with only `tensors`.
        self.kv_connector_output = kv_connector_output

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        A string key returns the stored tensor. A slice returns a new
        instance whose tensors are each indexed with that slice. Any other
        key type falls through and returns None.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        """Return the (name, tensor) view of the underlying dict."""
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same keys and `torch.equal` tensors.

        `kv_connector_output` does not take part in the comparison.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[k], other.tensors[k])
            for k in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1140
1141
1142
1143
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Return the summed ``get_data_nbytes()`` of every output."""
        total_bytes = 0
        for output in self.outputs:
            total_bytes += output.get_data_nbytes()
        return total_bytes

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        item = self.outputs[idx]
        return item

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when ``other`` is of this class and the outputs match."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1164
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Collect the sequence ids of every group into one flat list.

    Ids are taken from the keys of each group's ``seq_data``, in list order
    and then in key order.
    """
    all_seq_ids = []
    for group in seq_group_metadata_list:
        all_seq_ids.extend(group.seq_data)
    return all_seq_ids


1172
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Given a list of SequenceGroupMetadata, create a list of all
    sequence ids together with a request id -> sequence ids mapping.

    Groups whose ``seq_data`` is empty add no entry to the mapping.
    """
    all_seq_ids: list[int] = []
    ids_by_request: defaultdict[str, set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        group_seq_ids = list(group.seq_data)
        if not group_seq_ids:
            # Nothing to record; do not create an empty mapping entry.
            continue
        all_seq_ids.extend(group_seq_ids)
        ids_by_request[group.request_id].update(group_seq_ids)
    return all_seq_ids, ids_by_request


1187
1188
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        """Derive `_seq_ids` from the metadata list when one is given."""
        metadata_list = self.seq_group_metadata_list
        if metadata_list is None:
            return
        assert len(self.seq_group_metadata_list) == len(self.hidden_states)
        self._seq_ids = get_all_seq_ids(metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is None:
            return

        # Adding dummy hidden_states to this to maintain same shape
        if second_last_token_hidden_states is None:
            appended = torch.zeros_like(hidden_states)
        else:
            appended = second_last_token_hidden_states
        self.second_last_token_hidden_states = torch.cat(
            [self.second_last_token_hidden_states, appended])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        tracked = self._seq_ids
        requested = get_all_seq_ids(seq_group_metadata_list)
        # Only keep sequence IDs that exist in self._seq_ids
        kept = [seq_id for seq_id in requested if seq_id in tracked]
        if kept == tracked:
            # Batch contents unchanged - nothing to prune.
            return

        # Batch contents changed - prune removed sequences.
        rows = [tracked.index(seq_id) for seq_id in kept]
        self.hidden_states = self.hidden_states[rows]
        if self.second_last_token_hidden_states is not None:
            self.second_last_token_hidden_states = (
                self.second_last_token_hidden_states[rows])
        self._seq_ids = kept

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        tracked = self._seq_ids
        # Rows at or beyond this offset address the second-last-token states
        # in the concatenated tensor built below.
        offset = len(tracked)
        gather = []
        for seq_id in tracked:
            position = tracked.index(seq_id)
            if seq_id in seq_with_bonus_token_in_last_step:
                gather.append(position + offset)
            gather.append(position)

        stacked = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])
        self.hidden_states = stacked[gather]

1274

1275
1276
1277
1278
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    @property
    def is_last_step(self) -> bool:
        """True when the first group's state has exactly one step left."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        groups = self.seq_group_metadata_list
        assert len(groups) > 0
        first_state = groups[0].state
        assert first_state is not None
        return first_state.remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The `current_step` of the first group's state."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        groups = self.seq_group_metadata_list
        assert len(groups) > 0
        first_state = groups[0].state
        assert first_state is not None
        return first_state.current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list."""
        # The three block lists are shallow-copied and the sampled-token
        # tensor is cloned; every other field is passed through as is.
        if self.last_sampled_token_ids is not None:
            sampled_copy = self.last_sampled_token_ids.clone()
        else:
            sampled_copy = None
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=sampled_copy,
            async_callback=self.async_callback)
1346
1347
1348
1349
1350
1351
1352
1353
1354


@dataclass
class SequenceGroupBase:
    """Bookkeeping for one original request that is split into several
    sub-requests.

    `add_request` and `maybe_assemble_group` raise `NotImplementedError`
    here and are meant to be provided by subclasses; `finish_seq` is the
    shared bookkeeping step.
    """
    group_id: str  # the original request id before splitting

    # sequence group representing the whole original request; None until
    # a subclass assigns it
    assembled_seq_group: Optional[SequenceGroup] = None

    # seq id to a unique index inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # seq ids to be finished
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # seq id to finished sequences
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    # whether the group is in streaming mode (subclasses set this)
    streaming: bool = False

    # whether the assembled group has already been handed out once
    # (subclasses set this)
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Raises:
            NotImplementedError: Always; subclasses provide the behavior.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """The sequence `seq` finishes, we should record the information.

        Moves the entry keyed by `seq.request_id` from `to_be_finished`
        into `finished_reqs`.

        Raises:
            KeyError: If `seq.request_id` is not in `to_be_finished`.
        """
        del self.to_be_finished[seq.request_id]
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Raises:
            NotImplementedError: Always; subclasses provide the behavior.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Group for a request with `n` parallel samples, run in the engine
    as `n` separate sub-requests with `n = 1` each."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split `request_id` into `params.n` sub-requests and register them.

        For each index `i` a clone of `params` is made with `n = 1` and,
        when a seed is set, `seed + i`. The clone is submitted through
        `engine._add_processed_request` under the id
        `f"{request_id}_parallel_sample_{i}"`. The new group is recorded
        in `engine.seq_id_to_seq_group` for every sub-request id.

        Args:
            request_id: Id of the original (unsplit) request.
            engine: Object providing `_add_processed_request` and the
                `seq_id_to_seq_group` mapping.
            params: Sampling parameters of the original request.
            **kwargs: Forwarded unchanged to `_add_processed_request`.

        Returns:
            None. The group is reachable through
            `engine.seq_id_to_seq_group`.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        # NOTE: `original_params.n` must be >= 1. With n == 0 the loop
        # never binds `seq_group` and the construction below fails.
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                # Give every sub-request a distinct seed.
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        # Request-level attributes are taken from the last sub-request.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            priority=seq_group.priority,
        )

        group.streaming = params.output_kind == RequestOutputKind.DELTA
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when `seq_group` should emit output.

        Streaming mode: the assembled group is returned only for the
        sub-request that is first in `to_be_finished`.

        Non-streaming mode: the assembled group is returned once, when
        `seq_group` is the only entry left in `to_be_finished` and is
        finished. If `sampling_params._real_n` is not None, the group's
        `seqs` are first reduced to the top `n` by cumulative logprob.

        Returns:
            The assembled `SequenceGroup`, or None when no output should
            be produced for this call.
        """
        # in the streaming mode, we will return the assembled sequence
        # for the first remaining sequence, and then return None for the
        # rest of sequences
        if self.streaming:
            # The default keeps an empty `to_be_finished` from leaking
            # StopIteration out of this method.
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        # in the non-streaming mode, we will return the assembled sequence
        # when the last sequences finishes, and then return None for the
        # rest of the time
        is_last_to_finish = (len(self.to_be_finished) == 1
                             and seq_group.request_id in self.to_be_finished
                             and seq_group.is_finished())
        if not is_last_to_finish:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Get the top-n sequences.
            n = params._real_n or params.n
            seqs = self.assembled_seq_group.seqs
            sorted_seqs = sorted(seqs,
                                 key=lambda seq: seq.get_cumulative_logprob(),
                                 reverse=True)
            self.assembled_seq_group.seqs = sorted_seqs[:n]
        return self.assembled_seq_group