sequence.py 51.9 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass
8
from functools import cached_property, reduce
9
10
11
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
import msgspec
14
15
import torch

16
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
17
from vllm.lora.request import LoRARequest
18
from vllm.pooling_params import PoolingParams
19
from vllm.prompt_adapter.request import PromptAdapterRequest
20
from vllm.sampling_params import SamplingParams
21
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
if TYPE_CHECKING:
24
    from vllm.inputs import LLMInputs
25
    from vllm.multimodal.base import MultiModalDataDict
26

27
# Typecode of the `array.array` buffers that hold token ids in SequenceData.
# "l" is a C signed long in Python's `array` module (at least 4 bytes; the
# exact size is platform dependent).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"


# We use a dataclass for now because it is used for the OpenAI server
# output, and msgspec is not serializable there.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Info for supporting OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The chosen token decoded to text, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]


class SequenceStatus(enum.IntEnum):
    """Lifecycle state of a sequence.

    The members are ordered: every value greater than SWAPPED (2) denotes a
    finished sequence, which is what `is_finished` relies on.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Finished states -- keep them strictly after SWAPPED.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* states."""
        return SequenceStatus.SWAPPED < status

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into its finish-reason string.

        Returns None for any status that has no finish reason.
        """
        # Built locally: a dict assigned in the class body would itself be
        # turned into an enum member.
        reason_by_status = {
            SequenceStatus.FINISHED_STOPPED: "stop",
            SequenceStatus.FINISHED_LENGTH_CAPPED: "length",
            SequenceStatus.FINISHED_ABORTED: "abort",
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so -- as in the OpenAI API -- they are
            # reported as "length" too.
            SequenceStatus.FINISHED_IGNORED: "length",
        }
        return reason_by_status.get(status)


class SequenceStage(enum.Enum):
    """Phase a sequence is in: prompt prefill, then token-by-token decode."""
    # Explicit values equal to what enum.auto() assigns (starting at 1).
    PREFILL = 1
    DECODE = 2


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Timestamp tied to the most recent token of the
                         request (presumably refreshed as tokens are
                         produced -- confirm with the caller). SequenceGroup
                         initializes it to `arrival_time`.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None


class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by `SequenceData.get_delta_and_reset` and consumed by
    `SequenceData.apply_delta`.
    """
    # NOTE(review): with array_like=True msgspec encodes this struct as a
    # positional array, so field order is presumably part of the wire
    # format -- confirm before reordering fields.
    # New tokens to be appended to the existing SequenceData.
    new_output_token_ids: List[int]
    # Overwrites the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Args:
        _prompt_token_ids: The token IDs of the prompt, as an array with
            typecode VLLM_TOKEN_ID_ARRAY_TYPE.
        _output_token_ids: The token IDs of the output, as an array with
            typecode VLLM_TOKEN_ID_ARRAY_TYPE. Defaults to an empty array.

    Prefer the `from_seqs` / `from_token_counts` constructors, which build
    the arrays from plain integer sequences.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
        """Build prompt-only data from (token_id, count) runs.

        For example, ``from_token_counts((0, 2), (7, 1))`` yields the
        prompt [0, 0, 7]. With no arguments the prompt is empty.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        arrs = [
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
            for token_id, count in token_counts
        ]

        return SequenceData(reduce(array.__add__, arrs))

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """Build data from plain prompt (and optional output) token ids."""
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        # Token-id buffers must use the module-wide typecode; compare
        # against the constant so the check cannot drift from it.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token list from the arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses typecode VLLM_TOKEN_ID_ARRAY_TYPE ("l", a C
        signed long) whose item size is platform dependent (see
        `.itemsize`), so it is not guaranteed to match torch.long (int64).
        So beware of the usage.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids: List[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses typecode VLLM_TOKEN_ID_ARRAY_TYPE ("l", a C
        signed long) whose item size is platform dependent (see
        `.itemsize`), so it is not guaranteed to match torch.long (int64).
        So beware of the usage.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if none."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the changes since the last call, then clear the pending
        appended-token list."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Overwrite counters/stage from `delta` and append its new tokens."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False.) Must be True
                             for decoder-only model.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            carry valid encoder prompt fields.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt

        # For decoder-only models, a Sequence is constructed from an
        # LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs` instance could be
        # utilized to construct either an encoder sequence or a decoder
        # sequence, because `LLMInputs` has both decoder- and
        # encoder-oriented member variables (i.e. it encapsulates both an
        # encoder and a decoder prompt.) The decision of which type of
        # sequence to generate is determined by `from_decoder_prompt`.
        #
        # When constructing an encoder sequence (`from_decoder_prompt`
        # False) the encoder-related member variables of `inputs` must be
        # populated; an exception is raised below if this is not the case.
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether they are populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all tokens."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @cached_property
    def prompt(self) -> Optional[str]:
        # Select decoder or encoder input prompt str, as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        return cast(Optional[str], self.inputs.get(prompt_key))

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        # Select decoder or encoder input prompt token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        return cast(List[int], self.inputs.get(prompt_token_ids_key))

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return output token ids for an output update.

        If delta is False, all output token ids are returned. If delta is
        True, only tokens produced since the previous delta call are
        returned: a bare int when exactly one token is new, otherwise a
        list (empty when nothing new was produced).
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens <= 0:
            # Nothing new (or the output shrank). This must be handled
            # explicitly: `xs[-0:]` is `xs[0:]`, which would return every
            # prompt and output token.
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx`, combined
        with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")


class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable, per-sequence-group state.

    At present it only tracks progress through multi-step decoding.
    """

    # Multi-step decoding: the total number of steps, and how many of them
    # have been consumed so far (see `remaining_steps`).
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of multi-step decoding steps still to run."""
        steps_consumed = self.current_step
        return self.num_steps - steps_consumed


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Cached flag that lets the accessors below skip list scans in the
        # common single-sequence case; add() and remove() keep it in sync.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        # NOTE(review): presumably a cached output object assigned elsewhere;
        # it is only initialised here.
        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter, or 0 when absent."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        """Reset multi-step state to run ``num_scheduler_steps`` steps."""
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time since the last token and advance that timestamp.

        Computes ``now - metrics.last_token_time`` and then records ``now``
        as the new ``metrics.last_token_time``.

        Args:
            now: The current time, in the same clock as ``arrival_time``.

        Returns:
            The elapsed time since the previously recorded token.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings.

        The time is recorded only once, and only when the first sequence has
        exactly one output token.
        """
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since, from the
        #   user's point of view, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings.

        Only the first call has an effect; later calls are ignored.
        """
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            best_of = self.sampling_params.best_of
            assert isinstance(best_of, int)
            return best_of
        else:
            if self.sampling_params:
                best_of = self.sampling_params.best_of
                assert isinstance(best_of, int)
                if best_of > self.num_seqs():
                    # At prompt stage, the sequence group is not yet filled up
                    # and only have one sequence running. However, in the
                    # generation stage, we will have `best_of` sequences
                    # running.
                    return best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the group's sequences, optionally filtered by ``status``.

        With no filter the internal ``self.seqs`` list itself is returned,
        not a copy.
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """True when this group carries an encoder sequence."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None for decoder-only groups."""
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Return the sequences that are not finished."""
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        """Return the sequences that are finished."""
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far.

        Only unfinished sequences are advanced.
        """
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over the unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count sequences, optionally only those with ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        """Count the sequences that are not finished."""
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        """Count the sequences that are finished."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with ``seq_id``.

        Raises:
            ValueError: If no sequence with that id is in the group.
        """
        try:
            return self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def add(self, seq: Sequence) -> None:
        """Append ``seq`` to the group.

        Raises:
            ValueError: If a sequence with the same id already exists.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with ``seq_id`` from the group.

        Raises:
            ValueError: If no sequence with that id is in the group.
        """
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
        """True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.seqs[0].is_finished()

        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """True when the group is in the prefill stage."""
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
892
893


894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.

    Because the struct is ``array_like``, instances are serialized as
    positional arrays, so the declaration order of the fields below is part
    of the wire format and must not be changed.
    """
    # Per-sequence data deltas, keyed by sequence id.
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Must match the request_id of the metadata the delta is applied to.
    request_id: str
    # Seq id -> list of physical block numbers; replaces the previous tables.
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[List[int]] = None
    # The lambda defers the lookup of SequenceGroupState until an instance is
    # created.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Because the struct is ``array_like``, the declaration order of the fields
    is part of the serialized format and must not be changed.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters, used by embedding models.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            When left as None it is filled in by ``__post_init__``: the
            length of the first sequence for prompts, 1 otherwise.
    """

    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any because msgspec
    # doesn't allow a union of 2 different dict types.
    multi_modal_data: Optional[Any] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this state outside of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size when none was given: a prompt processes the
        # full length of its first sequence, a decode step a single token.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Id of the prompt adapter request, or 0 when there is none."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter, or 0 when absent."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a scheduler delta.

        Applies every per-sequence data delta, then overwrites
        ``block_tables``, ``token_chunk_size``, ``do_sample`` and
        ``is_prompt`` with the delta's values. ``computed_block_nums`` and
        ``state`` are not touched.

        The request ids must match (checked with ``assert``).
        """
        delta = sequence_group_metadata_delta
        # Validate before mutating so a delta for another request cannot
        # partially update this request's sequence data.
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one completed step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1

1014

1015
1016
1017
1018
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    # Field order is the array_like wire layout; keep it unchanged.
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, token and logprobs with another output.

        Raises:
            NotImplementedError: If ``other`` is not a SequenceOutput.
        """
        if isinstance(other, SequenceOutput):
            same_token = (self.parent_seq_id == other.parent_seq_id
                          and self.output_token == other.output_token)
            same_logprobs = other.logprobs == self.logprobs
            return same_token and same_logprobs
        # Comparison against any other type is treated as an error.
        raise NotImplementedError()
1044
1045


1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs of a sequence group.

    Subclasses must implement both ``__repr__`` and ``__eq__`` before they
    can be instantiated.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""


1058
1059
1060
1061
1062
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE(review): in Python 3 ``__metaclass__`` is only an ordinary class
    # attribute; it does not make this class a SequenceGroupOutput subclass.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs with another output.

        Raises:
            NotImplementedError: If ``other`` is not a
                CompletionSequenceGroupOutput.
        """
        if isinstance(other, CompletionSequenceGroupOutput):
            return (self.samples == other.samples
                    and self.prompt_logprobs == other.prompt_logprobs)
        raise NotImplementedError()

1078

1079
1080
1081
1082
1083
class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""

    # NOTE(review): in Python 3 ``__metaclass__`` is only an ordinary class
    # attribute; it does not make this class a SequenceGroupOutput subclass.
    __metaclass__ = SequenceGroupOutput
    # Embedding values are floats (SequenceGroup.embeddings is typed
    # List[float]). msgspec validates this annotation when decoding and also
    # accepts integers for float fields, so integer payloads still decode.
    embeddings: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings with another embedding output.

        Raises:
            NotImplementedError: If ``other`` is not an
                EmbeddingSequenceGroupOutput.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            raise NotImplementedError()
        return self.embeddings == other.embeddings


1098
1099
1100
1101
class IntermediateTensors(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor by name, or a new instance sliced along dim 0.

        Raises:
            KeyError: If a string key is not present.
            TypeError: If ``key`` is neither a str nor a slice.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """True when ``other`` holds the same names with equal tensors."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal checks both shape and element values.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1128
1129
1130
1131
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Supports indexing, item assignment and ``len`` over ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Optional SpecDecodeWorkerMetrics attached to this output.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int):
        """Return the output at position ``idx``."""
        group_outputs = self.outputs
        return group_outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output at position ``idx`` with ``value``."""
        group_outputs = self.outputs
        group_outputs[idx] = value

    def __len__(self):
        """Number of outputs held."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal only to an instance of the same class with equal outputs."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1151
1152
1153
1154
1155
1156
1157
1158
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into one list.

    Ids appear group by group, in each group's ``seq_data`` key order.
    """
    collected: List[int] = []
    for group in seq_group_metadata_list:
        collected.extend(group.seq_data)
    return collected


1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids, both flat and grouped by request id.

    Args:
        seq_group_metadata_list: The sequence groups to scan.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``. ``seq_ids`` lists
        every sequence id in iteration order. The mapping sends each request
        id to the set of its sequence ids; it is a ``defaultdict(set)``, so
        looking up an unknown request id yields an empty set.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


1174
1175
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        # Derive the per-row sequence ids from the metadata when given.
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        """Sequence id of each row of ``hidden_states``."""
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Appends the new rows (and their sequence ids) after the existing
        ones. When second-last-token states are being tracked, zeros of the
        new rows' shape are appended in place of a missing
        ``second_last_token_hidden_states`` so both tensors stay aligned.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered to follow ``seq_group_metadata_list``. A
        ValueError is raised if a requested id has no stored row.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For each sequence that received a bonus token, its second-last-token
        row is inserted immediately before its last-token row.
        """
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        num_rows = len(self._seq_ids)
        index = []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                # After the concatenation below, row i of
                # second_last_token_hidden_states sits at i + num_rows.
                index.append(i + num_rows)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1259

1260
1261
1262
1263
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    Encoded by msgspec as an array (``array_like=True``), so the order of the
    fields below is part of the encoded form; ``omit_defaults=True`` lets
    msgspec drop fields that still hold their default values.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback; carried as-is and shared by reference with clones.
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self):
        """Return the ``state`` of the first sequence group in the batch.

        The multi-step properties below read their step counters from the
        first sequence group only.
        """
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first sequence group's ``current_step`` is 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group's ``remaining_steps`` is 1."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The ``current_step`` of the first sequence group."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        Every list field (the three block lists and ``finished_requests_ids``)
        is shallow-copied, so mutating a list on the clone does not affect
        this request. ``last_sampled_token_ids`` is cloned when present.
        ``previous_hidden_states`` and ``async_callback`` are shared by
        reference with the original.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the other list fields so the clone does not alias
            # this request's list.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)