sequence.py 56.4 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass, field
8
from functools import cached_property, reduce
9
10
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
                    Mapping, Optional)
11
12
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
import msgspec
15
16
import torch

17
from vllm.inputs.parse import is_encoder_decoder_inputs
18
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
20
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.inputs import SingletonInputs
26

27
# Typecode shared by every token-ID `array` built in this module
# ("l" is the C signed long typecode of the `array` module).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Placeholder value used for a token ID that is not valid.
VLLM_INVALID_TOKEN_ID = -1

31

32
33
34
35
36
def array_full(token_id: int, count: int):
    """Return an :class:`array` filled with ``count`` copies of ``token_id``.

    This is the :class:`array` counterpart of :func:`numpy.full`; the result
    uses the module-wide token-ID typecode ``VLLM_TOKEN_ID_ARRAY_TYPE``.
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    # Sequence repetition on a one-element array replicates it in C.
    return single * count


37
38
39
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
40
41
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token (a string, not a
            token index), if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


54
55
# Prompt logprobs: a list of optional {token_id -> Logprob} dicts, presumably
# one entry per prompt token position. An entry may be None (e.g. when no
# logprob is available for that position -- TODO confirm exact semantics).
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} dict per generated output token
# (Sequence.append_token_id appends one dict for every appended token).
SampleLogprobs = List[Dict[int, Logprob]]
59

Woosuk Kwon's avatar
Woosuk Kwon committed
60

61
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Every member whose value is greater than ``SWAPPED`` is a finished state.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Every member after SWAPPED (2) counts as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* members."""
        return SequenceStatus.SWAPPED < status

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its OpenAI-style finish reason.

        Returns None for statuses that are not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences are those whose prompt exceeds the model's
            # length cap, so they are reported as "length" like in the
            # OpenAI API.
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
93

94

95
96
97
98
99
class SequenceStage(enum.Enum):
    """Processing phase of a sequence: prefill first, then decode."""
    PREFILL = 1
    DECODE = 2


100
101
102
103
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Timestamp of the most recent token event for the
                         request (presumably refreshed as tokens are
                         produced -- confirm with callers).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
127
128


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and applied on the
    receiving side with ``SequenceData.apply_delta``.
    """
    # Output tokens appended since the previous delta; they are extended onto
    # the existing SequenceData's output tokens.
    new_output_token_ids: List[int]
    # Overwrites the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer the :meth:`from_seqs` / :meth:`from_prompt_token_counts`
    constructors, which convert plain sequences into token-ID arrays.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: Tuple[int, int]) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance by concatenating
        prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        :code:`(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # Each array_full() result is a fresh array, so in-place
        # concatenation into the first one is safe.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance from prompt and output
        token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached list of prompt + output token IDs."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l",
        a C signed long) whose item size is platform dependent
        (see ``array.itemsize``). Check it before reinterpreting the raw
        buffer as a fixed-width dtype such as torch.long.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l",
        a C signed long) whose item size is platform dependent
        (see ``array.itemsize``). Check it before reinterpreting the raw
        buffer as a fixed-width dtype such as torch.long.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and accumulate its logprob."""
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the changes since the last call and clear the token delta.

        Only the list of newly appended tokens is reset; the cumulative
        logprob, computed-token count and stage are sent as absolute values.
        """
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by :meth:`get_delta_and_reset`."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
378
379


Woosuk Kwon's avatar
Woosuk Kwon committed
380
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the :code:`SingletonInputs` instance
    passed in through the :code:`inputs` constructor argument.

    For encoder/decoder models, SingletonInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the SingletonInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from SingletonInputs decoder
                             prompt (True) or encoder prompt (False.) Must be
                             True for decoder-only model.

    Raises:
        ValueError: If ``from_decoder_prompt`` is False and ``inputs`` are not
            encoder/decoder inputs.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "SingletonInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt

        # For decoder-only models, a Sequence is constructed from the
        # decoder prompt fields of the `inputs` instance.
        #
        # For encoder/decoder models the same `inputs` instance could be
        # utilized to construct either an encoder sequence or a decoder
        # sequence, because it carries both decoder- and encoder-oriented
        # fields (i.e. it encapsulates both an encoder and a decoder
        # prompt). The decision of which type of sequence to generate is
        # determined by the `from_decoder_prompt` argument.
        #
        # When constructing an encoder sequence (`from_decoder_prompt`
        # False) the `inputs` instance must be valid in the sense that its
        # encoder-related fields are populated; below, an exception is
        # raised if this is not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # fields populated.
        if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (ceiling division)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @cached_property
    def prompt(self) -> Optional[str]:
        # Select decoder or encoder input prompt str, as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        return cast(Optional[str], self.inputs.get(prompt_key))

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        # Select decoder or encoder input prompt token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        return cast(List[int], self.inputs.get(prompt_token_ids_key))

    @property
    def multi_modal_data(self) -> MultiModalDataDict:
        inputs = self.inputs

        if (inputs.get("multi_modal_data")
                and inputs.get("encoder_multi_modal_data")):
            raise ValueError(
                "Multi-modal data in both encoder and decoder is not supported."
            )

        return cast(
            MultiModalDataDict,
            (inputs.get("multi_modal_data")
             or inputs.get("encoder_multi_modal_data") or {}),
        )

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        return self.inputs.get("multi_modal_placeholders") or {}

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        return self.inputs.get("mm_processor_kwargs") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned.

        While the sequence is unfinished and ``buffer_length`` is non-zero,
        the last ``buffer_length`` characters are held back.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned.

        In delta mode a single new token is returned as a bare ``int``
        rather than a one-element sequence.
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens == 0:
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
638

Woosuk Kwon's avatar
Woosuk Kwon committed
639

640
641
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group."""

    # Multi-step decoding bookkeeping: how many steps are planned and the
    # counter of the current step.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps still left to run: ``num_steps - current_step``."""
        steps_left = self.num_steps - self.current_step
        return steps_left


Woosuk Kwon's avatar
Woosuk Kwon committed
653
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.

    Note:
        The status/progress helpers below (``get_seqs`` with a status,
        ``get_finished_seqs``, ``is_finished``, ``is_prefill``,
        ``update_num_computed_tokens``, ...) consult only ``first_seq``; the
        group is effectively treated as a single sequence.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Raises IndexError if ``seqs`` is empty: a group needs >= 1 sequence.
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> MultiModalDataDict:
        return self.first_seq.multi_modal_data

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        return self.first_seq.multi_modal_placeholders

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        return self.first_seq.mm_processor_kwargs

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
                         if self.prompt_adapter_request else 0

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step state to ``num_steps`` planned steps."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialise the multi-step state from scheduler lookahead info.

        Without multi-step, the state is simply set to ``num_scheduler_steps``.
        With multi-step, a chunked prefill uses ``num_lookahead_slots`` steps
        and every other case uses ``num_lookahead_slots + 1`` steps; the
        asserts encode the relationships the scheduler is expected to keep.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-Step case
        is_prefill = self.is_prefill()

        # The asserts below reflect the expectations of the current system.
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
        else:
            is_decode: bool = not is_prefill
            # If it is a prefill, num_lookahead_slots must be 0
            assert num_lookahead_slots == 0 or is_decode
            # If it is a decode, num_lookahead_slots + 1 must match
            # the scheduler steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
            self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def get_last_latency(self, now: float) -> float:
        """Return the time since the last token and advance it to ``now``.

        Args:
            now: Current timestamp (presumably the same clock as
                ``arrival_time`` — confirm at call sites).

        Returns:
            ``now - metrics.last_token_time`` measured before the update;
            ``metrics.last_token_time`` is then set to ``now``.

        Raises:
            ValueError: If the group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings (only the first call has any effect)."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        return 0 if self.first_seq.is_finished() else 1

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or all of them only if ``first_seq`` has
        ``status`` (an empty list otherwise)."""
        if status is None:
            return self.seqs

        return self.seqs if self.first_seq.status == status else []

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_finished_seqs(self) -> List[Sequence]:
        return self.seqs if self.first_seq.is_finished() else []

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far (no-op once finished)."""
        seq = self.first_seq
        if not seq.is_finished():
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Uncomputed tokens of ``first_seq``, or 0 if it has finished."""
        num_uncomputed_tokens = 0
        seq = self.first_seq
        if not seq.is_finished():
            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        return 1 if self.first_seq.is_finished() else 0

    def is_finished(self) -> bool:
        return self.first_seq.is_finished()

    def is_prefill(self) -> bool:
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.

    Because the struct is ``array_like``, msgspec encodes the fields
    positionally, so their declaration order is part of the wire format.
    """
    # Seq id -> incremental changes to that sequence's SequenceData.
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Must match the request_id of the SequenceGroupMetadata being patched.
    request_id: str
    # Seq id -> list of physical block numbers; replaces the previous tables.
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    # NOTE(review): SequenceGroupMetadata.apply_delta in this file does not
    # copy ``computed_block_nums`` or ``state`` onto the target metadata.
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Because the struct is ``array_like``, msgspec encodes the fields
    positionally, so their declaration order is part of the wire format.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholder information.
        mm_processor_kwargs: Multimodal input processor / mapper overrides.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If left as None while ``seq_data`` is set, ``__post_init__``
            fills it with the length of the first sequence for prompts and
            with 1 otherwise.
        num_speculative_tokens: Number of speculative tokens adopted; None
            when speculative decoding is not used.
    """

    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
                        if self.prompt_adapter_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Patch this metadata in place with a scheduler delta.

        Raises:
            AssertionError: If the delta belongs to another request. The check
                runs before any mutation so a mismatched delta leaves this
                object untouched.
        """
        delta_msg = sequence_group_metadata_delta
        assert self.request_id == delta_msg.request_id
        for seq_id, seq_delta in delta_msg.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta_msg.block_tables
        self.token_chunk_size = delta_msg.token_chunk_size
        self.do_sample = delta_msg.do_sample
        self.is_prompt = delta_msg.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one (must not exceed num_steps)."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, "
            f"num_steps {self.state.num_steps}")
        self.state.current_step += 1


class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Value equality on parent id, token and logprobs.

        Returns ``NotImplemented`` for foreign types (per the Python data
        model) instead of raising, so ``output == None`` or membership tests
        in mixed containers evaluate to False rather than crashing.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)


class SequenceGroupOutput(ABC):
    """Abstract interface for the model output of one sequence group.

    Concrete outputs must provide a readable ``__repr__`` and a value-based
    ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Describe the output for debugging."""
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare with ``other`` by value."""
        ...


class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: ``__metaclass__`` is a Python 2 idiom; in Python 3 it is just an
    # inert class attribute, so this struct is not a SequenceGroupOutput
    # subclass. Kept only for backward compatibility.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Value equality on samples and prompt logprobs.

        Returns ``NotImplemented`` for foreign types (per the Python data
        model) instead of raising ``NotImplementedError``.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    # NOTE: ``__metaclass__`` is a Python 2 idiom; in Python 3 it is just an
    # inert class attribute, so this struct is not a SequenceGroupOutput
    # subclass. Kept only for backward compatibility.
    __metaclass__ = SequenceGroupOutput
    # NOTE(review): annotated List[int], but presumably carries float
    # embedding values — confirm before relying on msgspec type validation.
    embeddings: List[int]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Value equality on the embeddings.

        Returns ``NotImplemented`` for foreign types (per the Python data
        model) instead of raising ``NotImplementedError``.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Indexing with a ``str`` returns the named tensor; indexing with a
    ``slice`` returns a new ``IntermediateTensors`` in which every tensor is
    sliced along its first dimension.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and returned None silently, which also made
        # the legacy iteration protocol (``for x in it``) loop forever.
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal iff same class, same tensor names and equal tensor values.

        The previous implementation returned ``self`` (truthy whenever it held
        any tensor), so any two instances compared "equal".
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal: same size and elements; tensors are assumed to share a
        # device/dtype — TODO confirm for cross-device comparisons.
        return all(
            torch.equal(value, other.tensors[name])
            for name, value in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Behaves like a mutable sequence over ``outputs`` (indexing, item
    assignment and ``len``).
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Imported in the class body rather than at module top to avoid a
    # circular import.
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
        """Return the per-group output at position ``idx``."""
        outputs = self.outputs
        return outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the per-group output at position ``idx``."""
        outputs = self.outputs
        outputs[idx] = value

    def __len__(self):
        """Number of per-group outputs."""
        outputs = self.outputs
        return len(outputs)

    def __eq__(self, other: object):
        """Same class and equal ``outputs`` lists (False for other types)."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into one list.

    Ids appear group by group, each group's ids in ``seq_data`` key order.
    """
    collected: List[int] = []
    for group in seq_group_metadata_list:
        collected.extend(group.seq_data)
    return collected


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids, both as a flat list and grouped by request id.

    Args:
        seq_group_metadata_list: The groups to scan.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``. ``seq_ids`` lists
        every sequence id, group by group in ``seq_data`` key order. The
        mapping (a ``defaultdict(set)``) sends each group's ``request_id`` to
        the set of its sequence ids.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    # Sequence id per row of ``hidden_states``; maintained by
    # ``__post_init__``, ``update`` and ``prune``.
    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Appends ``hidden_states`` rows and their sequence ids. When this
        object already tracks second-last-token states, the matching rows are
        appended too (zeros if ``second_last_token_hidden_states`` is None) so
        both tensors keep the same number of rows.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered to follow the order of the given groups.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For every sequence in ``seq_with_bonus_token_in_last_step``, its
        second-last-token row is inserted immediately before its own row.
        No-op when there are no second-last states or no such sequences.
        """
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        # Rows [0, n) of the concatenation below are ``hidden_states``; rows
        # [n, 2n) are the matching second-last-token states.
        num_seqs = len(self._seq_ids)
        index = []
        # enumerate() replaces a per-item list.index() lookup, which made this
        # loop accidentally O(n^2) (seq ids in a batch are unique).
        for row, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(row + num_seqs)
            index.append(row)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    Because the struct is ``array_like``, msgspec encodes the fields
    positionally, so their declaration order is part of the wire format.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self) -> SequenceGroupState:
        """Return the multi-step state of the first sequence group.

        The batch-level step properties read only the first group.
        """
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        List fields are shallow-copied and ``last_sampled_token_ids`` is
        tensor-cloned so the clone can be mutated independently;
        ``previous_hidden_states`` and ``async_callback`` are shared.
        """
        last_sampled = (self.last_sampled_token_ids.clone()
                        if self.last_sampled_token_ids is not None else None)
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the block lists so the clone does not alias the
            # original's list.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=last_sampled,
            async_callback=self.async_callback)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a request that is split into several sub-requests.

    Each sub-request runs as its own ``SequenceGroup``; this object tracks
    which of them are still running and which have finished, and holds the
    group that is reported in place of the sub-requests. Subclasses define
    how to split (``add_request``) and how to re-assemble
    (``maybe_assemble_group``).
    """

    # Request id of the original request, before it was split.
    group_id: str

    # Group reported to callers in place of the individual sub-requests.
    assembled_seq_group: Optional[SequenceGroup] = None

    # Sub-request id -> its index inside this group.
    seq_id_to_index: Dict[str, int] = field(default_factory=dict)

    # Sub-request id -> sequence group, for sub-requests not yet finished.
    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)

    # Sub-request id -> sequence group, for sub-requests already finished.
    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)

    # Whether the request streams its outputs incrementally.
    streaming: bool = False

    # Whether the assembled output has already been handed out.
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Split request ``request_id`` into sub-requests and add them to
        ``engine``. Subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that the sub-request ``seq`` has finished.

        Moves the entry for ``seq.request_id`` out of ``to_be_finished``
        and stores ``seq`` under the same id in ``finished_reqs``.

        Raises:
            KeyError: if ``seq.request_id`` is not in ``to_be_finished``.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the group to report for ``seq_group`` (used to produce
        the final output or to re-add the request to the engine), or
        ``None`` when nothing should be reported. Subclasses must override
        this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Parallel sampling (``n > 1``) implemented by request splitting.

    A request asking for ``n`` samples is split into ``n`` sub-requests
    that each draw one sample; their sequences are re-assembled into a
    single ``SequenceGroup`` carrying the original sampling params.
    """

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split ``request_id`` into ``params.n`` single-sample sub-requests.

        Sub-request ``i`` is added to ``engine`` under the id
        ``f"{request_id}_parallel_sample_{i}"`` and mapped to the new group
        in ``engine.seq_id_to_seq_group``.

        Args:
            request_id: Id of the original request; becomes ``group_id``.
            engine: Engine exposing ``_add_processed_request`` and
                ``seq_id_to_seq_group``.
            params: Sampling params of the original request. They are not
                mutated; every sub-request receives its own deep copy.
            **kwargs: Forwarded to ``engine._add_processed_request``.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        # NOTE(review): assumes original_params.n >= 1, since `seq_group`
        # from the last iteration is used after the loop.
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            # Private copy per sub-request so no two sub-requests share (and
            # can mutate) one params object; each draws a single sample.
            params = copy.deepcopy(original_params)
            params.n = 1
            # Identical sub-requests with an identical seed would draw
            # identical samples; offset the seed so the n samples differ
            # while staying reproducible.
            if params.seed is not None:
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change. Attributes common to all sub-requests are taken
        # from the last one added.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            embeddings=seq_group.embeddings,
            pooling_params=seq_group.pooling_params,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Decide whether to report the assembled group for ``seq_group``.

        Args:
            seq_group: The sub-request whose output is being processed.

        Returns:
            ``self.assembled_seq_group`` when an output for the whole
            request should be produced now, otherwise ``None``.
        """
        request_id = seq_group.request_id
        pending = self.to_be_finished

        # Streaming: one "leader" sub-request emits the assembled group
        # whenever it is processed. The leader is the first sub-request
        # still pending (dicts keep insertion order, which add_request makes
        # sample-index order). Pinning the leader to index 0 would stall the
        # stream once sample 0 finishes before the others (presumably a
        # finished sub-request is not processed again). If nothing is
        # pending, `seq_group` was the last one to finish and leads.
        if self.streaming:
            leader_id = next(iter(pending), request_id)
            if request_id == leader_id:
                return self.assembled_seq_group
            return None

        # Non-streaming: emit the assembled group exactly once, after every
        # sub-request has finished. The last sub-request may reach this
        # point before or after `finish_seq` removed it from
        # `to_be_finished`, so both states count as "all finished".
        all_finished = not pending or (len(pending) == 1
                                       and request_id in pending
                                       and seq_group.is_finished())
        if not all_finished or self.output_produced:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        self.output_produced = True
        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            self.assembled_seq_group.seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)[:n]
        return self.assembled_seq_group