"vscode:/vscode.git/clone" did not exist on "9aabf7e7691d71658116ac8e122d360b1bd2fd4e"
sequence.py 60.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm.inputs import SingletonInputs
19
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.prompt_adapter.request import PromptAdapterRequest
23
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
# `array` typecode used for every token-id buffer in this module
# ("l" = C signed long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel value for an invalid token id.
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """[`array`][] equivalent of [numpy.full][].

    Build a token-id array of length `count` where every element is
    `token_id`.
    """
    # Repeating a one-element array avoids building a temporary Python list.
    single_token = array(VLLM_TOKEN_ID_ARRAY_TYPE, (token_id, ))
    return single_token * count


35
36
37
# NOTE: `Logprob` is a dataclass (not a msgspec struct) for now because it is
# used in the OpenAI server output, where msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Logprob information for one token, used to support OpenAI-compatible
    logprobs and token ranks.

    Attributes:
        logprob: The log probability of the chosen token.
        rank: The vocabulary rank of the chosen token (>= 1), if known.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = field(default=None)
    decoded_token: Optional[str] = field(default=None)


# {token_id -> Logprob} mappings for a sequence group. An entry is None when
# the corresponding sequence group does not require prompt logprobs.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> Logprob} mappings for a sequence group, one per sampled token.
SampleLogprobs = list[dict[int, Logprob]]
57

Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    Every status whose value is greater than ``SWAPPED`` is a finished
    status.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` comes after SWAPPED (a FINISHED_* value)."""
        return SequenceStatus.SWAPPED < status

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns None when `status` is not one of the FINISHED_* statuses.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            # Ignored sequences are those whose prompt length exceeds the
            # model's length cap, so, as in the OpenAI API, the reason is
            # reported as "length" too.
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
91

92

93
94
95
96
97
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prefill or decode."""
    # Explicit values match what `enum.auto()` assigns (starting at 1).
    PREFILL = 1
    DECODE = 2


98
99
100
101
@dataclass
class RequestMetrics:
    """Timing metrics collected for a single request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time recorded for the request's last token.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    # Required positional fields.
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    # Optional fields, unset until the corresponding event happens.
    finished_time: Optional[float] = field(default=None)
    scheduler_time: Optional[float] = field(default=None)
    model_forward_time: Optional[float] = field(default=None)
    model_execute_time: Optional[float] = field(default=None)
125
126


127
128
129
130
131
132
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Per-step delta of a `SequenceData`, sent to workers.

    The struct is declared ``array_like``, so it is encoded positionally;
    keep the field order stable.
    """
    # Token ids to be appended to the existing SequenceData's output.
    new_output_token_ids: list[int]
    # Replaces the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Replaces the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Replaces the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer building instances with `from_seqs` or `from_prompt_token_counts`;
    the underscore-prefixed fields are internal state.

    Attributes:
        prompt_token_ids: The token IDs of the prompt (read-only tuple).
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # `array.__iadd__` extends the first (freshly built) array in place
        # instead of allocating a new array for every concatenation.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        # Both token buffers must use the module-wide token-id typecode.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token-id list."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild the cached prompt + output embeddings (concatenated along
        dim 0). Requires `_prompt_embeds` to be set."""
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        The array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l"). This is the
        internal buffer, not a copy, so callers must not mutate it.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Store into the declared `_output_embeds` field. The getter reads
        # this field, and a msgspec Struct has no slot for other attribute
        # names.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        The array uses typecode `VLLM_TOKEN_ID_ARRAY_TYPE` ("l"). This is the
        internal buffer, not a copy, so callers must not mutate it.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token (and optionally its 1-D embedding),
        updating the cached views and the cumulative logprob."""
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) already
        computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the tokens appended since the last call plus the current
        counters into a delta, then clear the appended-token buffer."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset`."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
443
444


Woosuk Kwon's avatar
Woosuk Kwon committed
445
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance passed in through the `inputs`
    constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        # For "embeds" inputs, `prompt_token_ids` yields one placeholder 0 per
        # embedding row; the real embeddings are passed along separately.
        self.data = SequenceData.from_seqs(
            self.prompt_token_ids,
            prompt_embeds=self.inputs["prompt_embeds"]
            if self.inputs["type"] == "embeds" else None)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed for all tokens (ceil)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        if self.inputs["type"] == "embeds":
            return None
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> list[int]:
        if self.inputs["type"] == "embeds":
            return [0] * len(self.inputs["prompt_embeds"])
        return self.inputs["prompt_token_ids"]

    @property
    def token_type_ids(self) -> list[int]:
        if self.inputs["type"] == "embeds":
            return []
        return self.inputs.get("token_type_ids", [])

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        if self.inputs["type"] == "multimodal":
            return self.inputs["mm_kwargs"]

        return MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        if self.inputs["type"] == "multimodal":
            return self.inputs["mm_placeholders"]

        return {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned.

        In delta mode a single new token is returned as a bare int, and no
        new tokens as an empty list. If the output has shrunk since the last
        call (negative delta), an empty list is returned too.
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # A non-positive count means nothing new. A negative count would
        # otherwise turn `[-num_new_tokens:]` into a forward slice that
        # includes prompt tokens.
        if num_new_tokens <= 0:
            return []

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash((self.prompt_adapter_id, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob,
                                  token_embed)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
680

Woosuk Kwon's avatar
Woosuk Kwon committed
681

682
683
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group."""

    # Multi-step decoding bookkeeping: total step count and the current
    # step counter.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps left: `num_steps` minus `current_step`."""
        steps_taken = self.current_step
        return self.num_steps - steps_taken


Woosuk Kwon's avatar
Woosuk Kwon committed
695
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must be non-empty: the first element is
            used as the representative sequence of the group.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
            for single-draft speculative decoding but larger than
            that for multi-draft SD (currently not supported). Accepted for
            interface compatibility; this constructor neither stores nor
            reads it.
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # seqs[0] raises IndexError for an empty list, so callers must supply
        # at least one sequence.
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Lets the helpers below short-circuit the common one-sequence case
        # instead of scanning self.seqs.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """Prompt text of the representative (first) sequence."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """Prompt token ids of the representative (first) sequence."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        """Token type ids of the representative (first) sequence."""
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the first sequence if it has any, otherwise
        that of the encoder sequence, otherwise an empty MultiModalKwargs."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs({})

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders of the sequence that supplies the multi-modal data.

        The decoder/encoder choice is keyed on whether the first sequence has
        ``multi_modal_data`` (not on its placeholders); falls back to an empty
        dict when neither sequence applies.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter, or 0 when there is
        none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step state to run ``num_steps`` steps from 0."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize multi-step state from the scheduler's lookahead config.

        * Not multi-step: run ``num_scheduler_steps`` steps.
        * Multi-step chunked prefill: run ``num_lookahead_slots`` steps
          (asserted equal to ``num_scheduler_steps``).
        * Otherwise: run ``num_lookahead_slots + 1`` steps.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-Step case
        is_prefill = self.is_prefill()

        # The asserts below reflect the expectations of the current system.
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
        else:
            is_decode: bool = not is_prefill
            # If it is a prefill, num_lookahead_slots must be 0
            assert num_lookahead_slots == 0 or is_decode
            # If it is a decode, num_lookahead_slots + 1 must match
            # the scheduler steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
            self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, assertion fails.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the sequences, optionally filtered by ``status``.

        Note: when no filter is given (or the single sequence matches), the
        internal ``self.seqs`` list itself is returned, not a copy.
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Whether this group carries an encoder sequence."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None for decoder-only models."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences whose status is finished."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        # Finished sequences are skipped; only live ones advance.
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens across all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Number of sequences, optionally only those with ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Number of sequences whose status is finished."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """Whether every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Whether the representative (first) sequence is in prefill."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)


class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
    """
    # NOTE: array_like=True serializes fields positionally, so the declaration
    # order below is part of the wire format; do not reorder.
    # Per-sequence data deltas, keyed by seq id.
    seq_data_delta: dict[int, SequenceDataDelta]
    request_id: str
    # Block tables, keyed by seq id.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The parameters used to generate the pooler for a
            pooling model.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type ids.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Placeholders for the multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required. If left as None,
            ``__post_init__`` sets it to the full length of the first
            sequence for prompts and to 1 otherwise.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request (see the field comment below).
    """

    # NOTE: array_like=True serializes fields positionally, so the declaration
    # order below is part of the wire format; do not reorder.
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[MultiModalKwargs] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size: a prompt processes the whole (first)
        # sequence in one go, anything else processes a single token.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter, or 0 when there is
        none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a scheduler delta.

        NOTE(review): the delta's ``computed_block_nums`` and ``state`` are
        not applied here; confirm that is intended.
        """
        for seq_id, delta in (
                sequence_group_metadata_delta.seq_data_delta.items()):
            self.seq_data[seq_id].apply_delta(delta)
        assert self.request_id == sequence_group_metadata_delta.request_id
        self.block_tables = sequence_group_metadata_delta.block_tables
        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
        self.do_sample = sequence_group_metadata_delta.do_sample
        self.is_prompt = sequence_group_metadata_delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one, asserting it stays below
        ``num_steps``."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, "
            f"num_steps {self.state.num_steps}")
        self.state.current_step += 1


class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor; only its shape is
            included in ``repr`` and it is not part of ``==``.
    """
    # NOTE: array_like=True serializes fields positionally; keep the order.
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        output_embed_shape = (self.output_embed.shape
                              if self.output_embed is not None else None)
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        # output_embed is not part of the comparison.
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs tied to one sequence group.

    Concrete implementations must provide both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a human-readable debug representation."""
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""
        ...


class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: Python 3 does not honour a ``__metaclass__`` attribute; this is a
    # plain class attribute and the struct does not subclass
    # SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        # step_index is not part of the comparison.
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE: Python 3 does not honour a ``__metaclass__`` attribute; this is a
    # plain class attribute and the struct does not subclass
    # SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Size in bytes of ``data`` (assumes it is a torch.Tensor)."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        self_is_tensor = isinstance(self.data, torch.Tensor)
        other_is_tensor = isinstance(other.data, torch.Tensor)
        if self_is_tensor or other_is_tensor:
            # Tensor ``==`` is element-wise and yields a tensor whose truth
            # value is ambiguous (e.g. inside list comparison), so compare
            # shape and values explicitly and return a plain bool.
            return (self_is_tensor and other_is_tensor
                    and torch.equal(self.data, other.data))
        return self.data == other.data


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Look up a tensor by name, or slice every tensor.

        A ``str`` key returns the named tensor. A ``slice`` key returns a new
        IntermediateTensors whose tensors are each sliced by ``key``; the
        ``finished_*`` sets are not carried over.

        Raises:
            KeyError: if a ``str`` key is not present.
            TypeError: for any other key type.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same tensor names and each same-named pair
        has identical shape and values (``torch.equal``).

        Only ``tensors`` is compared; ``finished_sending`` and
        ``finished_recving`` are ignored.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Total ``nbytes`` of the data held by every output."""
        return sum(o.get_data_nbytes() for o in self.outputs)

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Given a list of SequenceGroupMetadata, create a list of all
    sequence ids.

    Ids appear in the order of the groups in the list and, within a group,
    in the iteration order of its ``seq_data`` mapping.
    """
    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Args:
        seq_group_metadata_list: The sequence groups to scan.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:
        ``seq_ids`` lists every sequence id in order of appearance, and
        ``request_id_seq_ids_mapping`` maps each request id to the set of
        its sequence ids. The mapping is a ``defaultdict(set)``, so indexing
        it with an unknown request id inserts and returns an empty set.
    """
    seq_ids: list[int] = []
    request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # NOTE: array_like=True serializes fields positionally; keep the order.
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    def _row_by_seq_id(self) -> dict[int, int]:
        """Map each tracked seq id to the first row it occupies.

        Gives the same answer as ``self._seq_ids.index(seq_id)`` for every
        tracked id, but is built in one pass so callers avoid O(n^2) scans.
        """
        rows: dict[int, int] = {}
        for row, seq_id in enumerate(self._seq_ids):
            rows.setdefault(seq_id, row)
        return rows

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        row_of = self._row_by_seq_id()
        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [seq_id for seq_id in seq_ids if seq_id in row_of]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [row_of[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = (
                    self.second_last_token_hidden_states[index])
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        num_seqs = len(self._seq_ids)
        row_of = self._row_by_seq_id()
        index: list[int] = []
        for seq_id in self._seq_ids:
            row = row_of[seq_id]
            if seq_id in seq_with_bonus_token_in_last_step:
                # Rows past num_seqs select from the appended
                # second_last_token_hidden_states in the concatenation below.
                index.append(row + num_seqs)
            index.append(row)

        # NOTE(review): self._seq_ids is not expanded to match the new number
        # of rows; confirm callers expect this.
        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    NOTE: with `array_like=True` the field order below defines the encoded
    layout, so fields must not be reordered.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self):
        """Return the `state` of the first sequence group in the batch.

        The multi-step properties below only look at the first group.

        Raises:
            AssertionError: If the batch is empty or the first group has no
                state.
        """
        # TODO(will) make this be able to handle batches with variable number
        # of steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """Whether the first sequence group is at step 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """Whether the first sequence group has exactly one step remaining."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The current step of the first sequence group."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The block swap/copy lists are shallow-copied and
        `last_sampled_token_ids` is tensor-cloned. `previous_hidden_states`,
        `finished_requests_ids` and `async_callback` are shared with `self`
        (not copied).
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a request that has been split into sub-requests.

    Subclasses decide how to split (`add_request`) and when a combined
    output should be produced (`maybe_assemble_group`).
    """
    group_id: str  # the original request id before splitting

    assembled_seq_group: Optional[SequenceGroup] = None

    # sub-request id -> a unique index inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # sub-request id -> sequence group, for sub-requests not yet finished
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # sub-request id -> sequence group, for finished sub-requests
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """When we are ready to add a request with request_id and params
        into the engine, we can split the request into multiple requests.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that the sub-request `seq` has finished.

        Moves `seq` from `to_be_finished` to `finished_reqs`, keyed by
        `seq.request_id`.

        Raises:
            KeyError: If `seq.request_id` is not in `to_be_finished`.
        """
        del self.to_be_finished[seq.request_id]
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Assemble the sequence group, for producing the final
        output, or adding request in the engine again.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Runs a request with `n` samples as `n` single-sample sub-requests
    and reassembles their sequences into one sequence group."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split `params.n` samples into `n` sub-requests on `engine`.

        Sub-request `i` gets id `f"{request_id}_parallel_sample_{i}"`, a
        clone of `params` with `n = 1`, and (if a seed is set) the seed
        offset by `i`. Each sub-request is registered in
        `engine.seq_id_to_seq_group` and in the group's `to_be_finished`.

        Args:
            request_id: Id of the original (un-split) request.
            engine: Engine providing `_add_processed_request` and
                `seq_id_to_seq_group`.
            params: The original sampling params.
            **kwargs: Forwarded to `engine._add_processed_request`.

        Raises:
            ValueError: If `params.n` is less than 1.
        """
        original_params = params
        num_samples = original_params.n
        if num_samples < 1:
            # Without at least one sub-request there is no sequence group to
            # take the shared metadata (arrival time, LoRA, ...) from.
            raise ValueError(
                f"Parallel sampling requires n >= 1, got n={num_samples}.")

        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(num_samples):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = original_params.clone()
            params.n = 1
            if params.seed is not None:
                # Distinct seed per sub-request, presumably so seeded
                # sub-requests do not all draw the same sample.
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change. Shared metadata is taken from the last
        # sub-request's sequence group.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group if an output should be produced now.

        Streaming mode: return the assembled group only when `seq_group` is
        the first sub-request still in `to_be_finished`; return None for the
        rest.

        Non-streaming mode: return the assembled group exactly once, when
        `seq_group` is the last unfinished sub-request and has finished.
        If `_real_n` is set, keep only the top sequences by cumulative
        logprob. Return None otherwise.
        """
        if self.streaming:
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        is_last_to_finish = (len(self.to_be_finished) == 1
                             and seq_group.request_id in self.to_be_finished
                             and seq_group.is_finished())
        if not is_last_to_finish or self.output_produced:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        self.output_produced = True
        if params._real_n is not None:
            # Get the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            self.assembled_seq_group.seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True,
            )[:n]
        return self.assembled_seq_group