"tests/vscode:/vscode.git/clone" did not exist on "0ddf88e16e6ef4d985716f5bdec60fd053a260fa"
sequence.py 56.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Sequence and its related classes."""
3
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import enum
5
from abc import ABC, abstractmethod
6
from array import array
7
from collections import defaultdict
8
from dataclasses import dataclass, field
9
10
from functools import reduce
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
11
from typing import Sequence as GenericSequence
12
from typing import Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
import msgspec
15
16
import torch

17
from vllm.inputs import SingletonInputs, SingletonInputsAdapter
18
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
20
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
# Typecode shared by every token-id ``array`` in this module. "l" is a
# signed C long, so negative ids such as VLLM_INVALID_TOKEN_ID fit.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Token id value reserved to denote an invalid token.
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """Return a token-id :class:`array` of ``count`` copies of ``token_id``.

    This is the :class:`array` counterpart of :func:`numpy.full`.
    """
    one_element = array(VLLM_TOKEN_ID_ARRAY_TYPE, (token_id, ))
    # Sequence repetition builds the result in C without a Python loop.
    return one_element * count


34
35
36
# We use a dataclass for now because Logprob is used in the OpenAI server
# output, where msgspec structs are not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), or None if unknown.
        decoded_token: The decoded text of the chosen token, or None if it
            has not been decoded.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


51
52
# Prompt logprobs of a sequence group: a list of {token_id -> Logprob}
# mappings. An entry is None when the corresponding sequence group does not
# require prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sampled logprobs of a sequence group: a list of {token_id -> Logprob}
# mappings, one appended per generated token.
SampleLogprobs = List[Dict[int, Logprob]]
56

Woosuk Kwon's avatar
Woosuk Kwon committed
57

58
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    Values are ordered: every status greater than ``SWAPPED`` is a finished
    status (see :meth:`is_finished`).
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* statuses."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its OpenAI-style finish reason.

        Returns:
            "stop", "length" or "abort" for finished statuses, and None for
            statuses that are not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
Woosuk Kwon's avatar
Woosuk Kwon committed
90

91

92
93
94
95
96
class SequenceStage(enum.Enum):
    """Whether a sequence is still prefilling or already decoding."""
    # Explicit values match what enum.auto() assigns to a plain Enum.
    PREFILL = 1
    DECODE = 2


97
98
99
100
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time of the request's last token, as recorded
                         by the caller.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
124
125


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step."""
    # NOTE: array_like=True encodes the fields positionally, so the field
    # order below is part of the serialized format.

    # Output token ids appended since the previous delta was taken; the
    # receiver appends them to its existing SequenceData.
    new_output_token_ids: List[int]
    # Overwrites the receiver's existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the receiver's existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the receiver's existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer the :meth:`from_seqs` and :meth:`from_prompt_token_counts`
    constructors, which build the token arrays with the expected typecode
    (``VLLM_TOKEN_ID_ARRAY_TYPE``).

    Args:
        _prompt_token_ids: The token IDs of the prompt, as an ``array``.
        _output_token_ids: The token IDs of the output, as an ``array``.
            Defaults to an empty array.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    # Tuple copy of the prompt, filled in by __post_init__ so that
    # prompt_token_ids / get_prefix_token_ids can return hashable values.
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Prompt + output token ids as a list; rebuilt by
    # _update_cached_all_tokens and extended on every append.
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: Tuple[int, int]) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance by concatenating
        prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        :code:`(token_id, count)`.
        """
        if len(token_counts) == 0:
            return SequenceData.from_seqs([])

        # In-place concatenation avoids allocating a new array per step.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance from prompt and output
        token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        """Validate array typecodes and build the derived caches."""
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached list of prompt + output token ids."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # The prompt is fixed after construction.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        The array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a signed
        C long). Its item size is platform dependent (8 bytes on 64-bit
        Linux/macOS, 4 bytes on Windows), so check ``itemsize`` before
        reinterpreting the buffer as a fixed-width dtype such as torch.long.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        The array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", a signed
        C long). Its item size is platform dependent (8 bytes on 64-bit
        Linux/macOS, 4 bytes on Windows), so check ``itemsize`` before
        reinterpreting the buffer as a fixed-width dtype such as torch.long.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and accumulate its logprob."""
        self._output_token_ids.append(token_id)
        # Tracked separately so get_delta_and_reset can ship only new tokens.
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return prompt + output token ids (the cached list itself)."""
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if none."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the state changes since the last call and reset them."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state. Rebinding (not clearing) keeps the list that was
        # handed to `delta` intact.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by :meth:`get_delta_and_reset`."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
385
386


Woosuk Kwon's avatar
Woosuk Kwon committed
387
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the :data:`SingletonInputs` instance
    passed in through the :code:`inputs` constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = SingletonInputsAdapter(inputs)
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (ceiling division)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        return self.inputs.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        return self.inputs.prompt_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self.inputs.prompt_embeds

    @property
    def token_type_ids(self) -> List[int]:
        return self.inputs.token_type_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.multi_modal_data

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        return self.inputs.multi_modal_placeholders

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        return self.inputs.mm_processor_kwargs

    @property
    def lora_int_id(self) -> int:
        """LoRA id of the request, or 0 when no LoRA request is set."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter request is set."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        # Otherwise the last `buffer_length` characters are held back.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned"""
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens == 0:
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix up to and including block `logical_idx`."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.
        """
        if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash((self.prompt_adapter_id, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence with seq_id `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the number of not-yet-computed tokens for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
610

Woosuk Kwon's avatar
Woosuk Kwon committed
611

612
613
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group"""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of multi-step decoding steps not yet taken."""
        return self.num_steps - self.current_step


Woosuk Kwon's avatar
Woosuk Kwon committed
625
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must be non-empty: ``seqs[0]`` is kept
            as ``first_seq`` and supplies the prompt-level properties.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        pooling_params: Optional[PoolingParams] = None,
        pooled_data: Optional[torch.Tensor] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Prompt-level properties below are all read from the first sequence.
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Enables the single-sequence fast paths used by several methods.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        # NOTE(review): never assigned in this class beyond initialization;
        # presumably populated by output-processing code — confirm.
        self.cached_request_output = None

    @property
    def prompt(self) -> Optional[str]:
        """Prompt text, taken from the first sequence."""
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        """Prompt token ids, taken from the first sequence."""
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[List[int]]:
        """Token type ids, taken from the first sequence."""
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalDataDict:
        """Multimodal data of the first (decoder) sequence if it has any,
        otherwise that of the encoder sequence, otherwise ``{}``."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_data
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_data
        return {}

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multimodal placeholders, chosen from the same sequence that
        ``multi_modal_data`` would be read from."""
        # The source is selected by whether the first sequence has multimodal
        # *data*, not by whether it has placeholders.
        if self.first_seq.multi_modal_data:
            return self.first_seq.multi_modal_placeholders
        elif self.encoder_seq is not None:
            return self.encoder_seq.multi_modal_placeholders
        return {}

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """Multimodal processor kwargs, chosen from the same sequence that
        ``multi_modal_data`` would be read from."""
        if self.first_seq.multi_modal_data:
            return self.first_seq.mm_processor_kwargs
        elif self.encoder_seq is not None:
            return self.encoder_seq.mm_processor_kwargs
        return {}

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Number of prompt-adapter virtual tokens, or 0 without an adapter."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_steps: int) -> None:
        """Reset the multi-step state to ``num_steps`` steps, none done."""
        self.state.num_steps = num_steps
        self.state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize the multi-step state from scheduler settings.

        Without multi-step, the group runs ``num_scheduler_steps`` steps.
        With multi-step and chunked prefill on a prefill, it runs
        ``num_lookahead_slots`` steps. Otherwise it runs
        ``num_lookahead_slots + 1`` steps.
        """
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-Step case
        is_prefill = self.is_prefill()

        # The asserts below reflect the expectations of the current system.
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
        else:
            is_decode: bool = not is_prefill
            # If it is a prefill, num_lookahead_slots must be 0
            assert num_lookahead_slots == 0 or is_decode
            # If it is a decode, num_lookahead_slots + 1 must match
            # the scheduler steps.
            assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
            self.init_multi_step(num_steps=num_lookahead_slots + 1)

    def set_last_token_time(self, now: float) -> None:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, assertion fails.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        self.last_token_latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Returns the latency of the last token."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since, from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.first_seq.get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.is_single_seq:
            return 0 if self.first_seq.is_finished() else 1
        return self.num_seqs() - self.num_finished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those with the given ``status``.

        With ``status=None`` the internal ``self.seqs`` list itself is
        returned (not a copy).
        """
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.first_seq.status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Whether this group carries an encoder sequence."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, or None for decoder-only groups."""
        return self.encoder_seq

    def get_finished_seqs(self) -> List[Sequence]:
        """Return the sequences whose status is finished."""
        if self.is_single_seq:
            return self.seqs if self.first_seq.is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Number of sequences, optionally restricted to ``status``."""
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_finished_seqs(self) -> int:
        """Number of sequences whose status is finished."""
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0
        return len(self.get_finished_seqs())

    def is_finished(self) -> bool:
        """Whether every sequence in the group is finished."""
        if self.is_single_seq:
            return self.first_seq.is_finished()
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        """Whether the first sequence is still in the prefill stage."""
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
888
889


890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
    ``SequenceGroupMetadata.apply_delta`` applies it to a previously received
    ``SequenceGroupMetadata`` in place.
    """
    # NOTE: with ``array_like=True`` msgspec encodes fields positionally, so
    # the declaration order below is part of the serialized format.
    # Per-sequence data deltas, keyed by seq id.
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Must equal the request_id of the metadata the delta is applied to.
    request_id: str
    # Replacement block tables (seq id -> block numbers).
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters (pooling models only).
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type ids.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Optional multi modal placeholder info.
        mm_processor_kwargs: Multimodal input processor / mapper overrides.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If None and ``seq_data`` is set, ``__post_init__`` fills it in:
            the length of the first sequence for prompts, 1 otherwise.
        num_speculative_tokens: Number of speculative tokens (see the field
            comment below).
    """

    # NOTE: ``array_like=True`` serializes fields positionally; do not
    # reorder the declarations below.
    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    token_type_ids: Optional[List[int]] = None
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size: the whole first sequence for prompts,
        # a single token otherwise.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Number of prompt-adapter virtual tokens, or 0 without an adapter."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a scheduler delta.

        The delta's ``computed_block_nums`` and ``state`` fields are not
        applied here.
        """
        # Check the request id before touching any state.
        assert self.request_id == sequence_group_metadata_delta.request_id
        for seq_id, delta in (
                sequence_group_metadata_delta.seq_data_delta.items()):
            self.seq_data[seq_id].apply_delta(delta)
        self.block_tables = sequence_group_metadata_delta.block_tables
        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
        self.do_sample = sequence_group_metadata_delta.do_sample
        self.is_prompt = sequence_group_metadata_delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one completed step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, "
            f"num_steps {self.state.num_steps}")
        self.state.current_step += 1
1029

1030
1031
1032
1033
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with another ``SequenceOutput``.

        For any other type, return ``NotImplemented`` per the Python data
        model. Python then falls back to its default comparison
        (``==`` -> False) instead of raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
1059
1060


1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs tied to one sequence group.

    Concrete subclasses must implement ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""


1073
1074
1075
1076
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: a ``__metaclass__`` class attribute is a Python 2 idiom and has
    # no effect in Python 3. It does not make this class an (even virtual)
    # subclass of SequenceGroupOutput. It is kept only so the attribute
    # remains present.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equality on ``samples`` and ``prompt_logprobs``.

        For foreign types, return ``NotImplemented`` instead of raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

1093

1094
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # NOTE: a ``__metaclass__`` class attribute is a Python 2 idiom and has
    # no effect in Python 3. It is kept only so the attribute remains present.
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def __repr__(self) -> str:
        # The original string was missing its closing parenthesis.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare the pooled ``data`` of two outputs.

        Tensors are compared with ``torch.equal``, because tensor ``==`` is
        element-wise and would not yield a bool. Foreign types get
        ``NotImplemented`` instead of an exception.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            return NotImplemented
        if (isinstance(self.data, torch.Tensor)
                and isinstance(other.data, torch.Tensor)):
            return torch.equal(self.data, other.data)
        return self.data == other.data
1112
1113


1114
1115
1116
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Return the tensor named ``key`` (``str`` key).

        For a ``slice`` key, return a new ``IntermediateTensors`` in which
        every tensor is indexed with that slice.

        Raises:
            TypeError: If ``key`` is neither a ``str`` nor a ``slice``
                (previously this silently returned None).
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """Two instances are equal when they hold the same tensor names and
        each pair of tensors is ``torch.equal``.

        The previous implementation returned ``self``. Its truthiness came
        from ``__len__``, so it never compared the tensors at all.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1150
1151
1152
1153
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Outputs of one pooling pass, one entry per sequence group.

    The class behaves like a small mutable sequence over ``outputs``.
    """
    outputs: List[PoolingSequenceGroupOutput]

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        entries = self.outputs
        return entries[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        entries = self.outputs
        entries[idx] = value

    def __len__(self):
        entries = self.outputs
        return len(entries)

    def __eq__(self, other: object):
        # Only same-class instances can compare equal; then compare lists.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1171
1172
1173
1174
1175
1176
1177
1178
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into one list.

    Ids are the keys of each group's ``seq_data`` mapping, taken group by
    group in list order.
    """
    seq_ids: List[int] = []
    for group in seq_group_metadata_list:
        seq_ids.extend(group.seq_data)
    return seq_ids


1179
1180
1181
1182
1183
1184
1185
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids both as a flat list and grouped by request id.

    Returns:
        ``(seq_ids, request_id_seq_ids_mapping)``:
        - ``seq_ids`` holds every key of every group's ``seq_data``, in order.
        - ``request_id_seq_ids_mapping`` is a ``defaultdict(set)`` from
          request id to the seq ids seen for it. Groups with empty
          ``seq_data`` create no entry.
    """
    request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
    seq_ids: List[int] = []
    for group in seq_group_metadata_list:
        group_ids = list(group.seq_data)
        seq_ids.extend(group_ids)
        if group_ids:
            request_id_seq_ids_mapping[group.request_id].update(group_ids)
    return seq_ids, request_id_seq_ids_mapping


1194
1195
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # NOTE: ``array_like=True`` serializes fields positionally; do not
    # reorder the declarations below.
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    # Sequence id for each row of ``hidden_states``, in the same order.
    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        """Sequence id of each row of ``hidden_states``."""
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Appends ``hidden_states`` (and the matching seq ids) after the
        existing rows. ``second_last_token_hidden_states`` is extended only
        if this object already tracks such states. Zero rows are used when
        the caller passes None, which keeps both tensors the same length.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered to follow ``seq_group_metadata_list``. A seq id
        that is not currently tracked raises ``ValueError`` (from
        ``list.index``).
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = (
                    self.second_last_token_hidden_states[index])
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For every sequence in ``seq_with_bonus_token_in_last_step``, its row
        from ``second_last_token_hidden_states`` is inserted immediately
        before its own row of ``hidden_states``. ``_seq_ids`` is left
        unchanged.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        # Rows of second_last_token_hidden_states start at this offset in
        # the concatenation below.
        num_seqs = len(self._seq_ids)
        index: List[int] = []
        # enumerate replaces the former per-item ``list.index`` lookup
        # (O(n^2)). This is equivalent because seq ids within a batch are
        # expected to be unique.
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1279

1280
1281
1282
1283
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # NOTE: with ``array_like=True`` msgspec encodes this struct as an array
    # in field-declaration order, so the order of the fields below is part of
    # the serialized format; do not reorder, insert or remove fields casually.

    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Optional callback object carried along with the request.
    async_callback: Optional[Callable] = None

    def _first_seq_group_state(self):
        """Return the (non-None) state of the first sequence group.

        The step-related properties below all read only this state.
        """
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first sequence group's ``current_step`` is 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group has exactly one step left."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The first sequence group's ``current_step``."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        All list fields are shallow-copied and ``last_sampled_token_ids`` is
        cloned, so mutating those on the clone does not affect ``self``.
        ``previous_hidden_states`` and ``async_callback`` are shared.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the other list fields so the clone never aliases
            # this request's list.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411


@dataclass
class SequenceGroupBase:
    """State shared by the child requests an original request is split into.

    Subclasses implement ``add_request`` (how to split and submit) and
    ``maybe_assemble_group`` (when to hand back a combined group).
    """
    # Id of the original request, before it was split.
    group_id: str

    # Combined sequence group for the original request, if built
    # (subclasses are responsible for building it).
    assembled_seq_group: Optional[SequenceGroup] = None

    # Child request id -> unique index of that child within this group.
    seq_id_to_index: Dict[str, int] = field(default_factory=dict)

    # Child request id -> child SequenceGroup that has not finished yet.
    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)

    # Child request id -> child SequenceGroup that has finished.
    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)

    # Whether output is streamed rather than produced once at the end.
    streaming: bool = False

    # Whether the final output has already been handed back.
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Split the request ``request_id`` with ``params`` into child
        requests and add them to ``engine``. Must be overridden.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Mark child ``seq`` as finished.

        Moves it from ``to_be_finished`` into ``finished_reqs`` under its
        request id; raises KeyError if that id is not pending.
        """
        child_id = seq.request_id
        self.to_be_finished.pop(child_id)
        self.finished_reqs[child_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return a combined group to output for ``seq_group``, or None.

        Used both for producing the final output and for adding the request
        to the engine again. Must be overridden.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Runs a request with ``n`` samples as ``n`` single-sample requests and
    reassembles their sequences into one group for output."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Submit ``params.n`` single-sample child requests to ``engine``.

        Child ``i`` gets the id ``f"{request_id}_parallel_sample_{i}"`` and a
        deep copy of ``params`` with ``n = 1`` and, if ``params.seed`` is
        set, ``seed + i``. Every child id is registered in
        ``engine.seq_id_to_seq_group`` under the shared group object.

        The assembled group holds the first sequence of every child and uses
        the original ``params``. Its remaining attributes (arrival time,
        LoRA / pooling / prompt-adapter requests, encoder sequence, trace
        headers, priority) are taken from the last child.

        Args:
            request_id: id of the original request.
            engine: object providing ``_add_processed_request`` and
                ``seq_id_to_seq_group``.
            params: sampling parameters of the original request.
            **kwargs: forwarded unchanged to ``engine._add_processed_request``.

        Raises:
            ValueError: if ``params.n`` is less than 1.
        """
        original_params = params
        if original_params.n < 1:
            # Without any child there is no group to build the assembled
            # group from.
            raise ValueError(
                f"n must be at least 1, got {original_params.n}.")

        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = copy.deepcopy(original_params)
            params.n = 1
            if params.seed is not None:
                # Distinct seeds keep seeded samples from being identical.
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when output is due for ``seq_group``.

        Streaming mode: returns the assembled group only when ``seq_group``
        is the first child still in ``to_be_finished``, and None otherwise.

        Non-streaming mode: returns the assembled group once, when
        ``seq_group`` is the only child left in ``to_be_finished`` and it has
        finished; every other call returns None. If ``params._real_n`` is
        set, the assembled sequences are first cut down to the
        ``_real_n`` (or ``n`` if ``_real_n`` is 0) with the highest
        cumulative logprob.
        """
        if self.streaming:
            # A default of None means "no pending child": no request id can
            # match it, whereas a bare next() would raise StopIteration.
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        # Non-streaming: wait until the last unfinished child finishes.
        if not (len(self.to_be_finished) == 1
                and seq_group.request_id in self.to_be_finished
                and seq_group.is_finished()):
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            # The final output has already been handed back.
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Get the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            self.assembled_seq_group.seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)[:n]
        return self.assembled_seq_group