"vscode:/vscode.git/clone" did not exist on "b7215de2c5fcdf8af96cf941556d63934ea8f353"
sequence.py 57.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sequence and its related classes."""
4
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
from abc import ABC, abstractmethod
7
from array import array
8
from collections import defaultdict
9
10
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
11
from dataclasses import dataclass, field
12
from functools import reduce
13
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import msgspec
16
17
import torch

18
from vllm.inputs import SingletonInputs
19
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
21
from vllm.pooling_params import PoolingParams
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.multimodal.inputs import NestedTensors
26
27
28
    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorOutput)

29
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """Return a token-id `array` holding `count` copies of `token_id`.

    This is the [`array`][] counterpart of [numpy.full][]; the result uses
    the module-wide `VLLM_TOKEN_ID_ARRAY_TYPE` typecode.
    """
    # Build a one-element array and let sequence repetition expand it.
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single * count


39
40
41
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
42
43
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} for each sequence group. An entry is None if the
# corresponding sequence group doesn't require prompt logprobs.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = list[dict[int, Logprob]]
61

Woosuk Kwon's avatar
Woosuk Kwon committed
62

63
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Members are ordered by value; every status greater than ``SWAPPED`` is
    treated as finished.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True for any status ordered after ``SWAPPED``."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns None for any status that has no associated reason.
        """
        # FINISHED_IGNORED marks sequences whose prompt is longer than the
        # model's length cap, so, as in the OpenAI API, it reports "length".
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for finished_status, reason in reason_table:
            if status == finished_status:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
95

96

97
98
99
100
101
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prompt prefill or token decode."""
    # Explicit values equal to what enum.auto() assigns (starting at 1).
    PREFILL = 1
    DECODE = 2


102
103
104
105
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time associated with the request's most recent
            token. NOTE(review): the exact update point is set by callers
            outside this class; confirm there.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
129
130


131
132
133
134
135
136
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by `SequenceData.get_delta_and_reset` and consumed by
    `SequenceData.apply_delta`.
    """
    # NOTE: with array_like=True msgspec encodes this struct positionally,
    # so the field order below is part of the wire format; do not reorder.

    # Output token ids appended since the previous delta was taken; they
    # are appended to the receiver's existing SequenceData.
    new_output_token_ids: list[int]
    # Overwriting existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwriting existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwriting existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Token ids are stored as `array` objects using the
    `VLLM_TOKEN_ID_ARRAY_TYPE` typecode. Use `from_seqs` or
    `from_prompt_token_counts` to build an instance from plain sequences.

    Args:
        _prompt_token_ids: The token IDs of the prompt.
        _output_token_ids: The token IDs of the output. Defaults to an empty
            array.
        _prompt_embeds: Optional prompt embeddings.

    Attributes:
        prompt_token_ids: The token IDs of the prompt (as a tuple).
        output_token_ids: The token IDs of the output (as a tuple).
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    _prompt_embeds: Optional[torch.Tensor] = None
    _output_embeds: Optional[torch.Tensor] = None

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
    _cached_all_token_embeds: Optional[torch.Tensor] = None

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        by concatenating prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        `(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # In-place concatenation is safe: every operand is a fresh array.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
        *,
        prompt_embeds: Optional[torch.Tensor] = None,
    ) -> "SequenceData":
        """
        Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
        from prompt and output token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr,
                                _prompt_embeds=prompt_embeds)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr,
                            _prompt_embeds=prompt_embeds)

    def __post_init__(self) -> None:
        # Both arrays must share the module-wide typecode; arrays with
        # different typecodes cannot be concatenated.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()
        if self._prompt_embeds is not None:
            self._update_cached_all_token_embeds()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token id list."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    def _update_cached_all_token_embeds(self):
        """Rebuild the cached prompt (+ output) embeddings, concatenated on
        dim 0. Requires `_prompt_embeds` to be set."""
        assert isinstance(self._prompt_embeds, torch.Tensor)
        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
        if self._output_embeds is not None:
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds, self._output_embeds), dim=0)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses the `VLLM_TOKEN_ID_ARRAY_TYPE` ("l", C
        `long`) typecode, whose item size is platform-dependent. Check
        `itemsize` before reinterpreting the buffer as a torch dtype.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_embeds(self) -> Optional[torch.Tensor]:
        return self._output_embeds

    @output_embeds.setter
    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
        # Store into `_output_embeds`, the field read by the getter above and
        # by `_update_cached_all_token_embeds`.
        self._output_embeds = new_output_token_embeds
        self._update_cached_all_token_embeds()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses the `VLLM_TOKEN_ID_ARRAY_TYPE` ("l", C
        `long`) typecode, whose item size is platform-dependent. Check
        `itemsize` before reinterpreting the buffer as a torch dtype.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self._prompt_embeds

    @prompt_embeds.setter
    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
        self._prompt_embeds = prompt_embeds
        self._update_cached_all_token_embeds()

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token (and optionally its 1-D embedding),
        updating the cached token list, delta buffer and cumulative logprob.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # Do not pass in with batch or sequence dimensions
            assert token_embed.ndim == 1
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            assert self._cached_all_token_embeds is not None
            self._cached_all_token_embeds = torch.cat(
                (self._cached_all_token_embeds,
                 token_embed.to(device=self._cached_all_token_embeds.device)),
                dim=0)

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        return self._cached_all_token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        return self._cached_all_token_embeds

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the changes since the last call and clear the buffer of
        newly appended tokens."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset` on another copy."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"prompt_embeds.shape="
                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
447
448


Woosuk Kwon's avatar
Woosuk Kwon committed
449
class Sequence:
    """One sequence: its token data, status and block bookkeeping.

    A sequence is built from the
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
    or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    (for encoder-decoder) instance handed over through the `inputs`
    constructor argument.

    Args:
        seq_id: Identifier of this sequence.
        inputs: The (singleton) inputs the sequence is built from.
        block_size: Tokens per block; should match the block size used by
            the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: Optional LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        token_ids = self.prompt_token_ids
        embeds = None
        if self.inputs["type"] == "embeds":
            embeds = self.inputs["prompt_embeds"]
        self.data = SequenceData.from_seqs(token_ids, prompt_embeds=embeds)

        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Offsets remembering how much output the delta-mode getters below
        # have already handed out.
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # State for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold every token."""
        size = self.block_size
        return (self.get_len() + size - 1) // size

    @property
    def prompt(self) -> Optional[str]:
        """Prompt text if present; always None for embedding inputs."""
        if self.inputs["type"] != "embeds":
            return self.inputs.get("prompt")
        return None

    @property
    def prompt_token_ids(self) -> list[int]:
        """Prompt token ids; one 0 per embedding row for embedding inputs."""
        if self.inputs["type"] != "embeds":
            return self.inputs["prompt_token_ids"]
        return [0] * len(self.inputs["prompt_embeds"])

    @property
    def token_type_ids(self) -> list[int]:
        """Token type ids from the inputs, or an empty list."""
        if self.inputs["type"] != "embeds":
            return self.inputs.get("token_type_ids", [])
        return []

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multimodal kwargs for multimodal inputs, otherwise empty kwargs."""
        if self.inputs["type"] != "multimodal":
            return MultiModalKwargs()
        return self.inputs["mm_kwargs"]

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multimodal placeholders for multimodal inputs, otherwise {}."""
        if self.inputs["type"] != "multimodal":
            return {}
        return self.inputs["mm_placeholders"]

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text, holding back the last `buffer_length`
        characters while the sequence is unfinished.

        With `delta=True` only the text produced since the previous delta
        call is returned (and the internal offset advances).
        """
        # Once the sequence is finished the full text is returned.
        truncate = buffer_length and not self.is_finished()
        text = self.output_text

        if not delta:
            if truncate:
                return text[:-buffer_length]
            return text

        end = len(text)
        if truncate:
            end -= buffer_length

        start = self._last_output_text_offset
        if start >= end:
            return ""
        self._last_output_text_offset = end
        return text[start:end]

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return all output token ids, or with `delta=True` only those
        produced since the previous delta call (a bare int when exactly one
        token is new)."""
        if not delta:
            return self.get_output_token_ids()

        current_len = self.get_output_len()
        fresh_count = current_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = current_len

        if fresh_count == 1:
            # Fast path for the common single-token decode step.
            return self.data._cached_all_token_ids[-1]
        if fresh_count == 0:
            return []
        return self.data._cached_all_token_ids[-fresh_count:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash of all tokens up to the end of block `logical_idx`, combined
        with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        token_count = self.num_hashed_tokens_of_block(logical_idx)
        prefix = self.data.get_prefix_token_ids(token_count)
        return hash((prefix, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        Compute an extra hash for this sequence, used in prefix caching mode
        alongside the hashes of the tokens in the sequence's blocks.
        Returns None when no LoRA adapter is involved.
        """
        lora_id = self.lora_int_id
        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return None if lora_id == 0 else hash(lora_id)

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered up to the end of block `logical_idx`."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self,
                        token_id: int,
                        logprobs: dict[int, Logprob],
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record a new token; `logprobs` must contain an entry for it."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        chosen = logprobs[token_id]
        self.data.append_token_id(token_id, chosen.logprob, token_embed)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Deep-copy this sequence under a new id."""
        clone = copy.deepcopy(self)
        clone.seq_id = new_seq_id
        return clone

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            1 while decoding; otherwise the number of tokens that still
            need to be computed (prefill).
        """
        if self.data.stage != SequenceStage.DECODE:
            return self.data.get_num_uncomputed_tokens()
        return 1

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return ("Sequence("
                f"seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
676

Woosuk Kwon's avatar
Woosuk Kwon committed
677

678
679
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to one sequence group (multi-step decoding)."""

    # Total number of steps for multi-step decoding.
    num_steps: int = 1
    # Step counter; `remaining_steps` is `num_steps - current_step`.
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of steps in `num_steps` not yet counted by `current_step`."""
        steps_done = self.current_step
        return self.num_steps - steps_done


Woosuk Kwon's avatar
Woosuk Kwon committed
691
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must be non-empty: ``seqs[0]`` is
            cached as ``first_seq``.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
            model; equal to max number of tokens a step can generate
            for single-draft speculative decoding but larger than
            that for multi-draft SD (currently not supported).
            NOTE: this argument is accepted but is neither stored nor used
            by ``__init__``.
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Computed once here; the single-sequence fast paths in the methods
        # below rely on this flag.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """The prompt text of the first sequence."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token ids of the first sequence."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        """The encoder prompt text, or None without an encoder sequence."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        """The encoder prompt token ids, or None without an encoder
        sequence."""
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        """The token type ids of the first sequence."""
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalKwargs:
        """Multi-modal data of the first sequence if it has any, else that of
        the encoder sequence, else an empty ``MultiModalKwargs``."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return MultiModalKwargs()

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Placeholders taken from the same sequence that
        ``multi_modal_data`` is taken from; ``{}`` if neither applies."""
        # Deliberately keyed on multi_modal_data (not the placeholders) so
        # the data and its placeholders come from the same sequence.
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, assertion fails.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings. Only the first call has an effect."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        """Return the sequences, optionally filtered by ``status``.

        With ``status=None`` the internal ``self.seqs`` list itself is
        returned (not a copy).
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """True when an encoder sequence is attached."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, if any."""
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        """Return the sequences that are finished."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Number of sequences, optionally only those with ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Number of finished sequences."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Prefill status of the group, taken from the first sequence."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")

    def uses_prompt_embeds(self) -> bool:
        """Returns True if the sequence group uses input embeds."""
        return any(seq.data.prompt_embeds is not None for seq in self.seqs)


class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
    """
    # NOTE: array_like=True serializes fields positionally, so the field
    # order below is part of the wire format and must not be changed.
    # Per-sequence deltas keyed by seq id.
    seq_data_delta: dict[int, SequenceDataDelta]
    request_id: str
    # Seq id -> list of block numbers.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes query tokens for prefill.
        pooling_params: Pooling parameters (for pooling models).
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Token type ids.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholders.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required. When left as None and
            ``seq_data`` is set, ``__post_init__`` fills it in: the length of
            the first sequence for prompts, otherwise 1.
        num_speculative_tokens: See the field comment below.
    """

    # NOTE: array_like=True serializes fields positionally, so the field
    # order below is part of the wire format and must not be changed.
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[Union[MultiModalKwargs,
                                     dict[str, "NestedTensors"]]] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size: whole first sequence for a prompt,
        # a single token otherwise.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a SequenceGroupMetadataDelta.

        Applies each per-sequence delta to ``seq_data`` and overwrites
        ``block_tables``, ``token_chunk_size``, ``do_sample`` and
        ``is_prompt`` with the delta's values.
        """
        # Validate before mutating so a mismatched delta cannot leave
        # seq_data partially updated.
        assert self.request_id == sequence_group_metadata_delta.request_id
        for seq_id, delta in (
                sequence_group_metadata_delta.seq_data_delta.items()):
            self.seq_data[seq_id].apply_delta(delta)
        self.block_tables = sequence_group_metadata_delta.block_tables
        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
        self.do_sample = sequence_group_metadata_delta.do_sample
        self.is_prompt = sequence_group_metadata_delta.is_prompt
        # NOTE(review): the delta also carries ``computed_block_nums`` and
        # ``state``, which are not applied here -- confirm this is intended.

    def finish_step(self) -> None:
        """Advance ``state.current_step`` by one, asserting that it stays
        below ``state.num_steps`` beforehand."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, \
            f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa
        self.state.current_step += 1


class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        output_embed: Optional output embedding tensor.
    """
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]
    output_embed: Optional[torch.Tensor] = None

    def __repr__(self) -> str:
        # Print only the shape of the embedding, not its values.
        output_embed_shape = \
            self.output_embed.shape if self.output_embed is not None else None
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"output_embed.shape={output_embed_shape}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare ``parent_seq_id``, ``output_token`` and ``logprobs``.

        ``output_embed`` is not compared. A non-SequenceOutput operand yields
        ``NotImplemented`` so Python falls back to its default comparison
        instead of raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group.

    Concrete subclasses must implement both ``__repr__`` and ``__eq__``;
    the ABC machinery refuses to instantiate a subclass missing either.
    """

    @abstractmethod
    def __repr__(self) -> str:
        pass

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        pass


class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE(review): ``__metaclass__`` is a Python 2 idiom. In Python 3 it is
    # only a plain class attribute; it does not make this class a
    # SequenceGroupOutput subclass.
    __metaclass__ = SequenceGroupOutput
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare ``samples`` and ``prompt_logprobs`` (``step_index`` is not
        compared). Foreign operands yield ``NotImplemented``."""
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE(review): ``__metaclass__`` is a Python 2 idiom. In Python 3 it is
    # only a plain class attribute; it does not make this class a
    # SequenceGroupOutput subclass.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Return ``self.data.nbytes``; assumes ``data`` is a tensor."""
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare ``data``; tensors are compared with ``torch.equal``.

        A plain ``==`` between tensors yields an element-wise tensor whose
        truth value is ambiguous, so it cannot serve as ``__eq__``'s result.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            return NotImplemented
        if isinstance(self.data, torch.Tensor) and isinstance(
                other.data, torch.Tensor):
            return torch.equal(self.data, other.data)
        return self.data == other.data


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own kv_connector_output.
    """

    tensors: dict[str, torch.Tensor]
    kv_connector_output: Optional["KVConnectorOutput"]

    def __init__(self,
                 tensors: dict[str, torch.Tensor],
                 kv_connector_output: Optional["KVConnectorOutput"] = None
                 ) -> None:
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors
        # Always set the declared attribute; otherwise reading it on an
        # instance raises AttributeError.
        self.kv_connector_output = kv_connector_output

    def __getitem__(self, key: Union[str, slice]):
        """``obj["name"]`` returns one tensor; ``obj[a:b]`` returns a new
        IntermediateTensors with every tensor sliced by the same slice.

        Raises:
            TypeError: if ``key`` is neither a str nor a slice.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """True when ``other`` is an IntermediateTensors with the same keys
        and element-wise equal tensors (via ``torch.equal``)."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[k], other.tensors[k])
            for k in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Total ``get_data_nbytes()`` over all outputs."""
        return sum(o.get_data_nbytes() for o in self.outputs)

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # List equality delegates to PoolingSequenceGroupOutput.__eq__.
        return isinstance(other,
                          self.__class__) and self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Given a list of SequenceGroupMetadata, create a list of all
    sequence ids, in group order and then in each group's ``seq_data``
    key order.
    """
    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Given a list of SequenceGroupMetadata, collect all sequence ids and
    group them by request id.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``: the flat list of
        sequence ids in iteration order, and a ``defaultdict(set)`` mapping
        each request id to the set of its sequence ids.
    """
    seq_ids: list[int] = []
    request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""

    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    @staticmethod
    def _first_positions(seq_ids: list[int]) -> dict[int, int]:
        """Map each seq id to the index of its first occurrence in
        ``seq_ids`` (same result as ``seq_ids.index``, built in one pass)."""
        positions: dict[int, int] = {}
        for i, seq_id in enumerate(seq_ids):
            positions.setdefault(seq_id, i)
        return positions

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        # NOTE(review): if self.second_last_token_hidden_states is None, a
        # second_last_token_hidden_states argument passed here is dropped.
        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        # Build the lookup once instead of list scans per id.
        positions = self._first_positions(self._seq_ids)
        # Only keep sequence IDs that exist in self._seq_ids
        seq_ids = [seq_id for seq_id in seq_ids if seq_id in positions]
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [positions[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        num_seqs = len(self._seq_ids)
        positions = self._first_positions(self._seq_ids)
        index = []
        for seq_id in self._seq_ids:
            i = positions[seq_id]
            if seq_id in seq_with_bonus_token_in_last_step:
                # Row i of second_last_token_hidden_states sits at
                # i + num_seqs in the concatenation below.
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: array_like=True serializes fields positionally, so the field
    # order below is part of the wire format and must not be changed.
    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group has exactly one step left.

        Only the first group's state is inspected.
        """
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        assert first_seq_group.state is not None
        return first_seq_group.state.remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The current step of the first sequence group's state."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state.current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        List fields are shallow-copied and ``last_sampled_token_ids`` is
        cloned; ``previous_hidden_states`` and ``async_callback`` are shared.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the other list fields so the clone does not alias
            # the original's list.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for one incoming request that is split into several
    internal requests, each tracked as its own ``SequenceGroup``.

    Subclasses define how the split happens (``add_request``) and when the
    pieces are combined into a single group (``maybe_assemble_group``).
    """
    # Request id as originally submitted, before any splitting.
    group_id: str

    # The combined sequence group exposed to callers, once it exists.
    assembled_seq_group: Optional[SequenceGroup] = None

    # Internal (child) request id -> its unique index within this group.
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # Child requests still pending, keyed by request id.
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # Child requests that have finished, keyed by request id.
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    # Whether outputs are streamed incrementally; subclasses set this.
    streaming: bool = False

    # Flag subclasses use to remember that the final output was emitted.
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Hook invoked when request ``request_id`` with ``params`` is about
        to be added to ``engine``; lets a subclass split it into several
        internal requests. Not implemented on the base class.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Mark ``seq`` as finished.

        Drops its entry from ``to_be_finished`` (raising ``KeyError`` if it
        was not pending) and records ``seq`` in ``finished_reqs``.
        """
        req_id = seq.request_id
        del self.to_be_finished[req_id]
        self.finished_reqs[req_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Given an update for ``seq_group``, return the assembled group when
        it should be surfaced (to produce the final output, or to re-add
        the request to the engine), else ``None``. Not implemented on the
        base class.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Fan a request for ``params.n`` samples out into ``params.n``
        single-sample requests registered with ``engine``.

        Child ``i`` is added under id ``{request_id}_parallel_sample_{i}``
        with a clone of ``params`` whose ``n`` is 1 and whose ``seed`` (when
        set) is offset by ``i``. Each child id is mapped to the shared group
        in ``engine.seq_id_to_seq_group``. Extra keyword arguments are
        forwarded to ``engine._add_processed_request``.
        """
        parent_params = params
        group = ParallelSampleSequenceGroup(request_id)
        member_seqs = []

        for idx in range(parent_params.n):
            child_id = f"{request_id}_parallel_sample_{idx}"
            group.seq_id_to_index[child_id] = idx

            child_params = parent_params.clone()
            child_params.n = 1
            if child_params.seed is not None:
                # Presumably so seeded children do not draw identical
                # samples.
                child_params.seed += idx

            child_group = engine._add_processed_request(
                child_id,
                params=child_params,
                **kwargs,
            )  # type: ignore
            assert child_group is not None
            engine.seq_id_to_seq_group[child_id] = group
            group.to_be_finished[child_id] = child_group
            member_seqs.append(child_group.seqs[0])

        # Every member sequence exists at this point and the membership never
        # changes, so the assembled group can be built immediately.
        # Request-level metadata is copied from the last child added.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=member_seqs,
            arrival_time=child_group.arrival_time,
            sampling_params=parent_params,
            lora_request=child_group.lora_request,
            pooling_params=child_group.pooling_params,
            pooled_data=child_group.pooled_data,
            encoder_seq=child_group.encoder_seq,
            trace_headers=child_group.trace_headers,
            priority=child_group.priority,
        )

        is_delta = child_params.output_kind == RequestOutputKind.DELTA
        group.streaming = is_delta
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Decide whether an update for ``seq_group`` should surface the
        assembled parallel-sample group.

        Streaming mode: return the assembled group only when ``seq_group``
        is the earliest-added child still in ``to_be_finished``; return
        ``None`` for every other child (and when no child is pending).

        Non-streaming mode: return the assembled group exactly once, when
        ``seq_group`` is the only pending child and has finished. If
        ``sampling_params._real_n`` is not ``None``, the group's sequences
        are first trimmed to the ``_real_n`` (or ``n`` when ``_real_n`` is 0)
        sequences with the highest cumulative logprob. All other calls
        return ``None``.
        """
        if self.streaming:
            # dicts keep insertion order, so this is the earliest-added child
            # that has not finished. The ``None`` default avoids leaking
            # StopIteration when nothing is pending.
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        is_last_pending_and_done = (len(self.to_be_finished) == 1
                                    and seq_group.request_id
                                    in self.to_be_finished
                                    and seq_group.is_finished())
        if not is_last_pending_and_done:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)

        if self.output_produced:
            # The final output was already returned once; never repeat it.
            return None
        self.output_produced = True

        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            # NOTE(review): presumably more candidates than requested were
            # generated (e.g. best_of) - confirm against SamplingParams.
            n = params._real_n or params.n
            ranked = sorted(self.assembled_seq_group.seqs,
                            key=lambda seq: seq.get_cumulative_logprob(),
                            reverse=True)
            self.assembled_seq_group.seqs = ranked[:n]
        return self.assembled_seq_group