sequence.py 56.3 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass
8
from functools import cached_property, reduce
9
10
11
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
import msgspec
14
15
import torch

16
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
17
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
18
from vllm.lora.request import LoRARequest
19
from vllm.pooling_params import PoolingParams
20
from vllm.prompt_adapter.request import PromptAdapterRequest
21
from vllm.sampling_params import SamplingParams
22
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.multimodal.base import MultiModalDataDict
26

27
# Typecode of every ``array.array`` used to hold token ids in this module
# ("l" is a C signed long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel token id; presumably marks an invalid / placeholder token at its
# use sites (not visible here) — confirm before relying on it.
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """Return an ``array`` made of ``count`` copies of ``token_id``.

    This is the :class:`array` counterpart of :func:`numpy.full`. The result
    uses the ``VLLM_TOKEN_ID_ARRAY_TYPE`` typecode; a non-positive ``count``
    produces an empty array.
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, (token_id,))
    return single * count


37
38
39
# Logprob is a dataclass (not a msgspec.Struct) for now because it is part of
# the OpenAI-compatible server output, which cannot serialize msgspec structs.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), or None if unknown.
        decoded_token: The decoded text of the chosen token, or None if it
            has not been detokenized.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
59

Woosuk Kwon's avatar
Woosuk Kwon committed
60

61
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Values are ordered on purpose: every status strictly greater than
    ``SWAPPED`` is a finished status, which ``is_finished`` relies on.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Every status after SWAPPED (2) is treated as finished.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* values."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        FINISHED_IGNORED also maps to "length": ignored sequences are the
        ones whose prompt is longer than the model's length cap, which the
        OpenAI API reports as "length". Any other status yields None.
        """
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for finished_status, reason in reason_table:
            if status == finished_status:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
93

94

95
96
97
98
99
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prompt prefill, then token decode."""
    # Explicit values equal to what enum.auto() assigns (starting at 1).
    PREFILL = 1
    DECODE = 2


100
101
102
103
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time associated with the request's most recent
            token (presumably refreshed as tokens are produced — confirm at
            the update sites).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
127
128


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and consumed by
    ``SequenceData.apply_delta``. Because the struct is ``array_like``, it
    is encoded positionally: do not reorder the fields below.
    """
    # Output tokens appended since the previous delta was taken; they are
    # appended to the receiver's existing SequenceData.
    new_output_token_ids: List[int]
    # Overwriting existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwriting existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwriting existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer the ``from_seqs`` / ``from_prompt_token_counts`` /
    ``from_token_counts`` factories over the raw constructor: the
    constructor expects ``array`` objects whose typecode is
    ``VLLM_TOKEN_ID_ARRAY_TYPE`` (asserted in ``__post_init__``).

    Args:
        _prompt_token_ids: The token IDs of the prompt.
        _output_token_ids: The token IDs of the output. Defaults to an empty
            array.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    # Exposed via get_first_step_flag / set_first_step_flag; its meaning is
    # defined by callers outside this class.
    _first_step_flag: bool = True

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: Tuple[int, int]) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance by concatenating
        prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        :code:`(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # Every array_full(...) result is a fresh array, so extending the
        # first one in place (array += ...) is safe and avoids re-copying the
        # accumulated prefix for each pair.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
        """Same as :meth:`from_prompt_token_counts`.

        Kept under this name for existing callers; it delegates so the two
        factories cannot drift apart.
        """
        return SequenceData.from_prompt_token_counts(*token_counts)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """Build a :class:`SequenceData` from plain token-id sequences.

        Both inputs are copied into ``VLLM_TOKEN_ID_ARRAY_TYPE`` arrays;
        ``output_token_ids=None`` leaves the output empty.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        """Validate the token arrays and derive the cached views."""
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        # Immutable copy of the prompt: returned by `prompt_token_ids` and
        # sliced by `get_prefix_token_ids`, whose result must be hashable.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the flat prompt + output token list from the arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Replacing the prompt after construction is not supported.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids as an ``array``.

        The array has typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a C signed
        long), whose item size is platform dependent; check ``itemsize``
        before reinterpreting its buffer as a fixed-width tensor dtype.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids: List[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        # Keep the flat prompt + output cache in sync with the new output.
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids as an ``array``.

        The array has typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a C signed
        long), whose item size is platform dependent; check ``itemsize``
        before reinterpreting its buffer as a fixed-width tensor dtype.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        # Recorded separately so the next delta only carries new tokens.
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return prompt + output token ids (the internal cached list)."""
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if none."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Snapshot the changes since the last call, then clear them."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by ``get_delta_and_reset``."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def get_first_step_flag(self) -> bool:
        """Return ``_first_step_flag`` (True on construction)."""
        return self._first_step_flag

    def set_first_step_flag(self, flag: bool) -> None:
        self._first_step_flag = flag

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
393
394


Woosuk Kwon's avatar
Woosuk Kwon committed
395
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False.) Must be True
                             for decoder-only model.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            carry valid encoder prompt fields.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing a encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        # `prompt_token_ids` reads `inputs` and `from_decoder_prompt`, so it
        # must only be evaluated after both are set above.
        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all tokens."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @cached_property
    def prompt(self) -> Optional[str]:
        # Select decoder or encoder input prompt str, as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        return cast(Optional[str], self.inputs.get(prompt_key))

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        # Select decoder or encoder input prompt token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        return cast(List[int], self.inputs.get(prompt_token_ids_key))

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """Return decoder multi-modal data, else encoder's, else ``{}``.

        Raises:
            ValueError: If both decoder and encoder multi-modal data are set.
        """
        decoder_mm_data = self.inputs.get("multi_modal_data")
        encoder_mm_data = cast(EncoderDecoderLLMInputs,
                               self.inputs).get("encoder_multi_modal_data")
        if decoder_mm_data and encoder_mm_data:
            raise ValueError(
                "Multi-modal data in both encoder and decoder is not supported."
            )
        return decoder_mm_data or encoder_mm_data or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text to hand back to the caller.

        While the sequence is unfinished, the last `buffer_length`
        characters are held back. If delta is True, only new text since the
        last call to this method is returned.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned.

        In delta mode a single new token is returned as a bare int, no new
        tokens as an empty list, and several new tokens as a list.
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens <= 0:
            # Nothing new since the last call. This guard is required: with
            # num_new_tokens == 0 the slice below would be `[-0:]`, which is
            # `[0:]`, i.e. the whole prompt + output token list.
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix through block `logical_idx` plus LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append `token_id`; `logprobs` must contain an entry for it."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
658

Woosuk Kwon's avatar
Woosuk Kwon committed
659

660
661
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state owned by a single sequence group.

    Currently this only tracks multi-step decoding progress.
    """

    # Multi-step decoding: total number of steps, and how many of them have
    # been taken so far (see `remaining_steps`).
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Steps still left to run: ``num_steps - current_step``."""
        steps_taken = self.current_step
        return self.num_steps - steps_taken


Woosuk Kwon's avatar
Woosuk Kwon committed
673
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        self.request_id = request_id
        # NOTE: the caller's list is stored as-is (not copied); add()/remove()
        # mutate it in place.
        self.seqs = seqs
        self.arrival_time = arrival_time
        # Cached flag for the single-sequence fast paths below; add() and
        # remove() keep it in sync with len(self.seqs).
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        """LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual token count of the prompt adapter, or 0 if none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        """Reset the multi-step state to run ``num_scheduler_steps`` steps."""
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time since the last token and record ``now`` as it.

        Computes ``now - metrics.last_token_time`` and then stores ``now`` in
        ``metrics.last_token_time`` so the next call measures the following
        inter-token interval.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings. Only the first call has an effect."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            best_of = self.sampling_params.best_of
            assert isinstance(best_of, int)
            return best_of
        else:
            if self.sampling_params:
                best_of = self.sampling_params.best_of
                assert isinstance(best_of, int)
                if best_of > self.num_seqs():
                    # At prompt stage, the sequence group is not yet filled up
                    # and only have one sequence running. However, in the
                    # generation stage, we will have `best_of` sequences
                    # running.
                    return best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally filtered by ``status``.

        NOTE: with no filter (or a matching single sequence) the internal
        ``self.seqs`` list itself is returned, not a copy.
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """True when an encoder sequence is attached to this group."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Return the sequences that are not finished yet."""
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        """Return the sequences that are finished."""
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far (unfinished seqs only)."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences, optionally only those with ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Look up a sequence by id.

        Raises:
            ValueError: If no sequence with ``seq_id`` is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add a sequence to the group.

        Raises:
            ValueError: If a sequence with the same id is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        """Remove a sequence from the group.

        Raises:
            ValueError: If no sequence with ``seq_id`` is in the group.
        """
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
        """True when every sequence in the group is finished."""
        if self.is_single_seq:
            return self.seqs[0].is_finished()

        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size. The receiver folds it
    into its cached SequenceGroupMetadata via
    ``SequenceGroupMetadata.apply_delta``.
    """
    # NOTE: msgspec's array_like=True encodes fields positionally, so the
    # declaration order below is part of the serialized format; do not
    # reorder.
    # Per-sequence data deltas, keyed by seq id.
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Must match the request_id of the metadata the delta is applied to.
    request_id: str
    # Seq id -> block numbers; replaces the receiver's block tables wholesale.
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    # NOTE(review): computed_block_nums and state are carried here but are not
    # read by SequenceGroupMetadata.apply_delta — confirm this is intended.
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters, used by embedding models.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If left as None, ``__post_init__`` fills it in: the ``get_len()``
            of the first entry of ``seq_data`` for a prompt, otherwise 1.
        num_speculative_tokens: Number of speculative tokens adopted in this
            request. None means speculative decoding is not used; zero means
            it is disabled for some reason.
    """

    # NOTE: array_like=True serializes fields positionally; keep the order.
    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Derive a default chunk size: the whole (first) sequence for a
        # prompt, a single token otherwise.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual token count of the prompt adapter, or 0 if none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Fold a scheduler delta into this metadata in place.

        Applies each per-sequence data delta, then replaces ``block_tables``,
        ``token_chunk_size``, ``do_sample`` and ``is_prompt`` with the
        delta's values.

        Raises:
            AssertionError: If the delta belongs to a different request.
        """
        delta = sequence_group_metadata_delta
        # Validate before mutating anything so a mismatched delta cannot
        # leave this object's sequence data partially updated.
        assert self.request_id == delta.request_id
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step state by one step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1


class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with another SequenceOutput."""
        # Returning NotImplemented (instead of raising) lets Python fall back
        # to the reflected comparison / identity, so ``out == None`` and
        # membership tests over mixed-type lists evaluate to False.
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)


class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs belonging to a sequence group.

    Concrete outputs must provide a readable ``__repr__`` and an ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a human-readable description of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""


class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE(review): Python 3 ignores a ``__metaclass__`` class attribute; this
    # does not make the class a subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when samples and prompt logprobs both match."""
        # Defer to Python's fallback for foreign types instead of raising.
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    # NOTE(review): Python 3 ignores a ``__metaclass__`` class attribute; this
    # does not make the class a subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    # Embedding values are floats (cf. SequenceGroup.embeddings). msgspec
    # validates decoded data against this annotation; float also accepts
    # integer values, so this is strictly more permissive than List[int].
    embeddings: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Equal when the embedding vectors match element-wise."""
        # Defer to Python's fallback for foreign types instead of raising.
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


class IntermediateTensors(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor by name, or a new instance with every tensor
        sliced by ``key``.

        NOTE(review): any other key type falls through and returns None.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same tensor names with equal tensors.

        Tensors are compared with ``torch.equal`` (same size and elements).
        NOTE(review): comparing tensors that live on different devices may
        raise inside ``torch.equal``.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Behaves like a list of per-group embedding outputs (indexing, assignment
    and ``len`` all delegate to ``outputs``).
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int):
        items = self.outputs
        return items[idx]

    def __setitem__(self, idx: int, value):
        items = self.outputs
        items[idx] = value

    def __len__(self):
        count = len(self.outputs)
        return count

    def __eq__(self, other: object):
        # Only instances of the same class compare equal, and then only when
        # their output lists match.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into a single list.

    Ids appear group by group, and within a group in the iteration order of
    that group's ``seq_data`` mapping.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        # Iterating a dict yields its keys, i.e. the sequence ids.
        all_ids.extend(group.seq_data)
    return all_ids


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids, both flat and grouped by request.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``: ``seq_ids`` lists
        every sequence id in encounter order, and the mapping (a
        ``defaultdict(set)``) sends each request id that has at least one
        sequence to the set of its sequence ids.
    """
    pairs = [(group.request_id, seq_id)
             for group in seq_group_metadata_list
             for seq_id in group.seq_data]
    flat_ids: List[int] = [seq_id for _, seq_id in pairs]
    by_request: Dict[str, Set[int]] = defaultdict(set)
    for request_id, seq_id in pairs:
        by_request[request_id].add(seq_id)
    return flat_ids, by_request


class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        NOTE: ``second_last_token_hidden_states`` is only appended when this
        object already tracks second-last-token states; otherwise it is
        ignored.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = (
                    self.second_last_token_hidden_states[index])
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For every sequence in ``seq_with_bonus_token_in_last_step`` its
        second-last-token state is inserted immediately before its regular
        hidden state. No-op when there are no second-last-token states or no
        such sequences.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        # Assumes one row of hidden_states per entry of _seq_ids (seq ids are
        # unique), so in the concatenation below row ``i + num_seqs`` is the
        # second-last-token state of sequence ``i``.
        num_seqs = len(self._seq_ids)
        index = []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]


class Logits(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.

    Used in speculative decoding to pass lm_head logits from the target
    model to the proposer model in the subsequent step.

    seq_ids are the sequence ids of each entry of the batch dimension of the
    logits tensor.
    """
    # Scorer logits. The length checks below expect one leading-dimension
    # entry per sequence group — TODO confirm the layout for prefill steps.
    logits: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        groups = self.seq_group_metadata_list
        if groups is None:
            return
        assert len(groups) == len(self.logits)
        self._seq_ids = get_all_seq_ids(groups)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

    def update(self,
               logits: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Append logits from a target model invocation. Only used for
        decode steps."""
        assert len(seq_group_metadata_list) == len(logits)
        incoming_ids = get_all_seq_ids(seq_group_metadata_list)
        self._seq_ids.extend(incoming_ids)
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the rows of sequences still present in
        ``seq_group_metadata_list``. Only used for decode steps.

        NOTE: this drops every sequence absent from the list, so a sequence
        that is merely "paused" (and resumed later) would also be pruned.
        Ideally only sequences confirmed to be aborted should be removed.
        """
        kept_ids = get_all_seq_ids(seq_group_metadata_list)
        if kept_ids == self._seq_ids:
            return
        # Batch contents changed: gather the surviving rows in the new order.
        rows = [self._seq_ids.index(sid) for sid in kept_ids]
        self.logits = self.logits[rows]
        self._seq_ids = kept_ids


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
1361
1362
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
1363
    # The sequence group metadata list.
1364
1365
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
1366
    # Blocks to swap in. List of CPU -> GPU block number.
1367
1368
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
1369
    # Blocks to swap out. List of GPU -> CPU block number.
1370
1371
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
1372
    # Blocks to copy. Source to dest block.
1373
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
1374
1375
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
1376
1377
1378
1379
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
1380
1381
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
1382
1383
    # Optional logits from prior step.
    previous_logits: Optional[Logits] = None
1384
1385
    # The number of forward steps to run.
    num_steps: int = 1
Mor Zusman's avatar
Mor Zusman committed
1386
    # Finished request ids since last step.
1387
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
1388
1389
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
1390
1391
    # Async callback
    async_callback: Optional[Callable] = None
1392

1393
1394
1395
1396
1397
1398
    # Optional tree attention mask from draft model.
    tree_attn_masks: Optional[torch.Tensor] = None

    # Optional tree position ids from draft model.
    tree_position_ids: Optional[torch.Tensor] = None

1399
1400
1401
    # Optional slot mapping of kvcache that pending to be moved generated from draft model.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

1402
1403
1404
1405
1406
1407
    @property
    def is_first_multi_step(self) -> bool:
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
1408
        assert first_seq_group.state is not None
1409
1410
1411
1412
1413
1414
1415
1416
        return first_seq_group.state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
1417
        assert first_seq_group.state is not None
1418
        return first_seq_group.state.remaining_steps == 1
1419
1420
1421
1422
1423
1424

    @property
    def current_step(self) -> int:
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
1425
1426
1427
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state.current_step
1428
1429

    def clone(
1430
1431
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
1432
1433
1434
1435
1436
1437
1438
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list."""
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
1439
            virtual_engine=self.virtual_engine,
1440
1441
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
1442
            previous_hidden_states=self.previous_hidden_states,
1443
            previous_logits=self.previous_logits,
1444
            num_steps=self.num_steps,
1445
1446
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
1447
            if self.last_sampled_token_ids is not None else None,
1448
1449
            async_callback=self.async_callback,
            tree_attn_masks=self.tree_attn_masks,
1450
1451
            tree_position_ids=self.tree_position_ids,
            kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)