sequence.py 52.4 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass
8
from functools import cached_property, reduce
9
10
11
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
import msgspec
14
15
import torch

16
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
17
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
18
from vllm.lora.request import LoRARequest
19
from vllm.pooling_params import PoolingParams
20
from vllm.prompt_adapter.request import PromptAdapterRequest
21
from vllm.sampling_params import SamplingParams
22
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.multimodal.base import MultiModalDataDict
26

27
# `array.array` typecode used for every token-id buffer in this module
# ("l" is a C signed long; its item size is platform dependent).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Token id value reserved to mean "no valid token" (presumably used as a
# placeholder by callers elsewhere — confirm at call sites).
VLLM_INVALID_TOKEN_ID = -1

31
32
33
34

# We use a plain dataclass (not a msgspec.Struct) for now because instances
# are part of the OpenAI server output, where msgspec structs are not
# serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if it has been
            detokenized.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


49
50
# A list of {token_id -> Logprob} mappings for a sequence group; an entry is
# None when the corresponding sequence group does not request prompt
# logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# A list of {token_id -> Logprob} mappings for a sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
54

Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Members are ordered: every value greater than ``SWAPPED`` (2) denotes a
    terminal ("finished") status.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* members."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns None for any status that has no finish reason.
        """
        # FINISHED_IGNORED marks sequences whose prompt is longer than the
        # model's length cap, so it reports "length" as in the OpenAI API.
        # (Kept as a local tuple: a class attribute would become an enum
        # member.)
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for finished_status, reason in reason_table:
            if status == finished_status:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
88

89

90
91
92
93
94
class SequenceStage(enum.Enum):
    """Processing stage of a sequence.

    PREFILL while some of its tokens are not yet computed; DECODE once all of
    them have been computed.
    """
    PREFILL = enum.auto()
    DECODE = enum.auto()


95
96
97
98
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Presumably the time when the most recent token was
            produced for the request (set by callers outside this class —
            confirm).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and consumed by
    ``SequenceData.apply_delta``. Because the struct is declared with
    ``array_like=True``, msgspec encodes it as an array rather than a map,
    so the order of the fields below is part of the encoded form.
    """
    # Token ids appended since the previous delta; the receiver extends its
    # existing output tokens with them.
    new_output_token_ids: List[int]
    # Replaces the receiver's existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Replaces the receiver's existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Replaces the receiver's existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Token ids and per-step bookkeeping for a single sequence.

    Prefer the :meth:`from_seqs` / :meth:`from_token_counts` constructors
    over passing the underscore-prefixed fields directly. Token ids are held
    in ``array.array`` buffers with typecode ``VLLM_TOKEN_ID_ARRAY_TYPE``.

    Attributes:
        prompt_token_ids: The token IDs of the prompt (read-only tuple).
        output_token_ids: The token IDs of the output, returned as a tuple
            copy; assigning a new sequence replaces the output buffer.
        cumulative_logprob: The cumulative log probability of the output.
        stage: Whether the sequence is in the PREFILL or DECODE stage.
        mrope_position_delta: Value used to compute mrope position ids.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
        """Build a SequenceData whose prompt is given run-length encoded.

        Each ``(token_id, count)`` pair contributes ``count`` copies of
        ``token_id`` to the prompt, in argument order. The result has no
        output tokens.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        arrs = [
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
            for token_id, count in token_counts
        ]

        return SequenceData(reduce(array.__add__, arrs))

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """Build a SequenceData from plain token-id sequences.

        Args:
            prompt_token_ids: The token IDs of the prompt.
            output_token_ids: The token IDs of the output; an empty output
                buffer is used when None.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        # Both buffers must share one typecode: `_update_cached_all_tokens`
        # concatenates them, and `array + array` requires matching typecodes.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token-id list from scratch."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # The prompt is immutable once the sequence data is built.
        raise NotImplementedError(
            "SequenceData.prompt_token_ids is read-only")

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids as the underlying ``array.array``.

        The array has typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a C signed
        long) whose item size is platform dependent, so check
        ``array.itemsize`` before reinterpreting it as a fixed-width tensor
        dtype. This is the internal buffer, not a copy.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids as the underlying ``array.array``.

        The array has typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a C signed
        long) whose item size is platform dependent, so check
        ``array.itemsize`` before reinterpreting it as a fixed-width tensor
        dtype. This is the internal buffer, not a copy.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum.

        The token is also recorded for the next `get_delta_and_reset`.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return prompt + output token ids (the cached list, not a copy)."""
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get the first ``num_tokens`` tokens in a hashable form.

        Returns a ``(prompt_part, output_part)`` pair of tuples; the output
        part is None when ``num_tokens`` does not exceed the prompt length.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens already computed (run by the model)."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far.

        Switches the stage to DECODE once no uncomputed tokens remain.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens that are not computed yet."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the changes since the last call and clear the token log."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset` on another copy."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
360
361


Woosuk Kwon's avatar
Woosuk Kwon committed
362
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False.) Must be True
                             for decoder-only model.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            pass `is_valid_encoder_decoder_llm_inputs`.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing an encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (ceiling division)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @cached_property
    def prompt(self) -> Optional[str]:
        # Select decoder or encoder input prompt str, as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        return cast(Optional[str], self.inputs.get(prompt_key))

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        # Select decoder or encoder input prompt token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        return cast(List[int], self.inputs.get(prompt_token_ids_key))

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """Return the decoder or encoder multi-modal data (or ``{}``).

        Raises:
            ValueError: If both decoder and encoder multi-modal data are set.
        """
        decoder_mm_data = self.inputs.get("multi_modal_data")
        encoder_mm_data = cast(EncoderDecoderLLMInputs,
                               self.inputs).get("encoder_multi_modal_data")
        if decoder_mm_data and encoder_mm_data:
            raise ValueError(
                "Multi-modal data in both encoder and decoder is not supported."
            )
        return decoder_mm_data or encoder_mm_data or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned.

        Unless the sequence is finished, the last `buffer_length` characters
        are withheld.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned.

        In delta mode a bare ``int`` is returned when exactly one new token
        exists; otherwise a list, which is empty when nothing is new.
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        # Without this guard `[-0:]` would return the whole prompt+output
        # list when there are no new tokens.
        if num_new_tokens <= 0:
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at logical block `logical_idx`,
        together with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon committed
625

Woosuk Kwon's avatar
Woosuk Kwon committed
626

627
628
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable per-sequence-group state (multi-step decoding progress)."""

    # Multi-step decoding: total number of steps, and the progress counter
    # that `remaining_steps` subtracts from it.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """How many of `num_steps` are left after `current_step`."""
        steps_left = self.num_steps - self.current_step
        return steps_left


Woosuk Kwon's avatar
Woosuk Kwon committed
640
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.arrival_time = arrival_time
        # Cached flag so the common single-sequence case can skip list scans
        # in the query methods below. add()/remove() keep it up to date.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        """Reset the multi-step state to run ``num_scheduler_steps`` steps."""
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time since the previous token and advance the clock.

        Computes ``now - metrics.last_token_time``, then stores ``now`` as the
        new ``metrics.last_token_time``.

        Args:
            now: The current timestamp.

        Returns:
            The elapsed time since the last recorded token.

        Raises:
            ValueError: If the group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Set the first token time for request-level timings.

        The time is only recorded if it is unset and the first sequence has
        produced exactly one output token.
        """
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings (only on the first call)."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            best_of = self.sampling_params.best_of
            assert isinstance(best_of, int)
            return best_of
        else:
            if self.sampling_params:
                best_of = self.sampling_params.best_of
                assert isinstance(best_of, int)
                if best_of > self.num_seqs():
                    # At prompt stage, the sequence group is not yet filled up
                    # and only have one sequence running. However, in the
                    # generation stage, we will have `best_of` sequences
                    # running.
                    return best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally filtered by ``status``.

        With ``status=None`` the group's own ``seqs`` list is returned (not a
        copy).
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far (unfinished seqs only)."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with ``seq_id``.

        Raises:
            ValueError: If no sequence with that ID belongs to the group.
        """
        try:
            return self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def add(self, seq: Sequence) -> None:
        """Append ``seq`` to the group.

        Raises:
            ValueError: If a sequence with the same ID is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with ``seq_id`` from the group.

        Raises:
            ValueError: If no sequence with that ID belongs to the group.
        """
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
        if self.is_single_seq:
            return self.seqs[0].is_finished()

        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
906
907


908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size. The receiver merges
    it into its stored metadata via ``SequenceGroupMetadata.apply_delta``.
    """
    # NOTE: array_like=True serializes this struct positionally, so the field
    # order below is part of the wire format; do not reorder fields.
    # Per-sequence data deltas, keyed by sequence id.
    seq_data_delta: Dict[int, SequenceDataDelta]
    request_id: str
    # Seq id -> list of physical block numbers.
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[List[int]] = None
    # A fresh SequenceGroupState per instance (default_factory, not a shared
    # default object).
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters (embedding models).
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If not given (and ``seq_data`` is set), ``__post_init__`` derives
            it: the length of the first sequence for prompts, 1 otherwise.
        num_speculative_tokens: Lazily set number of speculative tokens (see
            the field comment).
    """

    # NOTE: array_like=True serializes positionally; field order is part of
    # the wire format.
    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size: a prompt processes the whole length of its
        # (first) sequence, a decode step processes a single token.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Merge an incremental update into this metadata in place.

        Each per-sequence delta is applied to the matching ``seq_data`` entry;
        ``block_tables``, ``token_chunk_size``, ``do_sample`` and
        ``is_prompt`` are overwritten with the delta's values.

        Args:
            sequence_group_metadata_delta: A delta for this same request.
        """
        delta = sequence_group_metadata_delta
        # Check the request id before mutating anything so that a mismatched
        # delta cannot be half-applied to seq_data.
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt
        # NOTE(review): the delta's computed_block_nums and state are not
        # applied here; confirm this is intentional.

    def finish_step(self) -> None:
        """Advance the multi-step decoding counter by one step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1

1028

1029
1030
1031
1032
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Model output for a single sequence at one step.

    Args:
        parent_seq_id: ID of the sequence this output extends (the parent when
            forking during beam search).
        output_token: The sampled token ID.
        logprobs: Token id -> Logprob, i.e. logP(x_i+1 | x_0, ..., x_i).
    """
    # Positional (array_like) fields: keep this order.
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        # Comparing against anything other than a SequenceOutput is treated
        # as a programming error rather than returning False.
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        if self.parent_seq_id != other.parent_seq_id:
            return False
        if self.output_token != other.output_token:
            return False
        return other.logprobs == self.logprobs
1058
1059


1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
class SequenceGroupOutput(ABC):
    """Abstract interface for the model output of one sequence group.

    Concrete subclasses must implement ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""
        ...


1072
1073
1074
1075
1076
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # The docstring must come first in the class body; placed after an
    # assignment it is only a discarded string expression and __doc__ is None.

    # NOTE(review): ``__metaclass__`` is a Python 2 idiom. In Python 3 it is
    # just a plain class attribute; it does not make this struct a (virtual)
    # subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        # Comparing against a different type is treated as a programming
        # error rather than returning False.
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

1092

1093
1094
1095
1096
1097
class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    # NOTE(review): ``__metaclass__`` is a Python 2 idiom. In Python 3 it is
    # just a plain class attribute; it does not make this struct a (virtual)
    # subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    # Embedding values are floats (SequenceGroup types its embeddings as
    # List[float]). A typed msgspec decode rejects floats for an ``int``
    # annotation, while ``List[float]`` also accepts integer values.
    embeddings: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        # Comparing against a different type is treated as a programming
        # error rather than returning False.
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            raise NotImplementedError()
        return self.embeddings == other.embeddings


1112
1113
1114
1115
class IntermediateTensors(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        Args:
            key: A tensor name, or a slice applied to each stored tensor.

        Returns:
            The named tensor, or a new IntermediateTensors whose tensors are
            ``v[key]`` for every stored ``v``.

        Raises:
            KeyError: If ``key`` is a name that is not stored.
            TypeError: If ``key`` is neither a str nor a slice.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and silently returned None.
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Two instances are equal when they hold the same names and, for each
        name, tensors with identical shape, dtype, device and values."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        for name, mine in self.tensors.items():
            theirs = other.tensors[name]
            # Check metadata first so torch.equal is never asked to compare
            # tensors across devices or dtypes.
            if (mine.shape != theirs.shape or mine.dtype != theirs.dtype
                    or mine.device != theirs.device):
                return False
            if not torch.equal(mine, theirs):
                return False
        return True

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1142
1143
1144
1145
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Acts as a thin sequence wrapper: indexing, item assignment and ``len``
    are forwarded to ``outputs``.
    """
    # Positional (array_like) fields: keep this order.
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int):
        items = self.outputs
        return items[idx]

    def __setitem__(self, idx: int, value):
        items = self.outputs
        items[idx] = value

    def __len__(self):
        items = self.outputs
        return len(items)

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1165
1166
1167
1168
1169
1170
1171
1172
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into a single list.

    Ids appear in batch order: groups in list order, and within each group in
    the iteration order of its ``seq_data`` mapping.
    """
    seq_ids: List[int] = []
    for seq_group in seq_group_metadata_list:
        seq_ids.extend(seq_group.seq_data)
    return seq_ids


1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids both as a flat list and grouped by request id.

    Args:
        seq_group_metadata_list: Metadata for each sequence group in the batch.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:

        * ``seq_ids``: every sequence id in batch order (groups in list order,
          then the iteration order of each group's ``seq_data``).
        * ``request_id_seq_ids_mapping``: request id -> set of its sequence
          ids. This is a ``defaultdict(set)``, so indexing an unknown request
          id returns (and inserts) an empty set.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


1188
1189
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Positional (array_like) fields: keep this order.

    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Appends the new rows (and their sequence ids) after the existing ones.
        If second-last-token states are being tracked, the matching rows are
        appended as well, using zeros when none are supplied.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered/filtered to match the sequence ids of
        ``seq_group_metadata_list``.

        Raises:
            ValueError: If a requested sequence id has no stored row.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            # Build the id -> row map once instead of calling list.index per
            # id (quadratic in batch size). setdefault keeps the first row of
            # a repeated id, matching list.index semantics.
            row_of: Dict[int, int] = {}
            for row, seq_id in enumerate(self._seq_ids):
                row_of.setdefault(seq_id, row)
            index: List[int] = []
            for seq_id in seq_ids:
                row = row_of.get(seq_id)
                if row is None:
                    raise ValueError(
                        f"Sequence {seq_id} has no stored hidden states.")
                index.append(row)
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = (
                    self.second_last_token_hidden_states[index])
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For each sequence in ``seq_with_bonus_token_in_last_step``, its
        second-last-token row is inserted just before its regular row. This is
        a no-op when there are no second-last-token states or no such
        sequences.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        num_rows = len(self._seq_ids)
        index: List[int] = []
        # enumerate gives the row directly; the old list.index lookup inside
        # this loop made it quadratic.
        for row, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                # In the concatenation below, second-last-token rows start
                # after all regular rows.
                index.append(row + num_rows)
            index.append(row)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1273

1274
1275
1276
1277
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    NOTE: the struct is declared ``array_like``, so msgspec encodes its
    fields by position; keep the field declaration order below stable.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self):
        """Return the ``state`` of the first sequence group in the batch.

        The step-related properties below read the whole batch's progress
        from this single state.
        TODO(will) make this be able to handle batches with variable number of
        steps.

        Raises:
            AssertionError: If the batch is empty or the first group has no
                state.
        """
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first sequence group's ``current_step`` is 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group has exactly one step left."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The first sequence group's ``current_step``."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The block lists and ``finished_requests_ids`` are shallow-copied, and
        ``last_sampled_token_ids`` is copied with ``Tensor.clone``. Adding or
        removing elements on the clone therefore does not affect ``self``.
        ``previous_hidden_states`` and ``async_callback`` are shared by
        reference.

        Args:
            seq_group_metadata_list: Metadata list for the new request; it is
                used as-is, not copied.

        Returns:
            A new ``ExecuteModelRequest`` whose other fields come from
            ``self``.
        """
        last_sampled_token_ids = (self.last_sampled_token_ids.clone()
                                  if self.last_sampled_token_ids is not None
                                  else None)
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the block lists so the clone never aliases the
            # original request's list.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=last_sampled_token_ids,
            async_callback=self.async_callback)