sequence.py 62.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Sequence and its related classes."""
3
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import enum
5
from abc import ABC, abstractmethod
6
from array import array
7
from collections import defaultdict
8
9
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
10
from dataclasses import dataclass, field
11
from functools import reduce
12
from typing import Any, Callable, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
import msgspec
15
16
import torch

17
from vllm.inputs import SingletonInputs, SingletonInputsAdapter
18
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
20
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
23

24
# Typecode used for every :class:`array` of token ids in this module.
# ``"l"`` is a C ``signed long``: at least 4 bytes, and 8 bytes on most
# 64-bit Unix platforms, so its item size is platform-dependent.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel for an invalid token id. It is not referenced in this part of the
# file; presumably it is consumed elsewhere (NOTE(review): confirm).
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
    """Build an :class:`array` of ``count`` copies of ``token_id``.

    This is the :class:`array` counterpart of :func:`numpy.full`. The element
    typecode is ``VLLM_TOKEN_ID_ARRAY_TYPE``.
    """
    single = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    # Sequence repetition on an array allocates the result in one step.
    repeated = single * count
    return repeated


# We use a dataclass for now because this type is used for the OpenAI server
# output, and msgspec structs are not serializable there.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), or None if unknown.
        decoded_token: The decoded text of the chosen token, or None if it
            has not been decoded.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Prompt logprobs: a list of optional {token_id -> Logprob} dicts. An entry is
# None when the corresponding sequence group doesn't require prompt logprobs.
# NOTE(review): the list appears to be indexed by prompt position; confirm.
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} dict per generated token of a
# sequence (Sequence.append_token_id appends one dict per new token).
SampleLogprobs = list[dict[int, Logprob]]


class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    The members are ordered. Every status greater than ``SWAPPED`` is a
    finished status (see :meth:`is_finished`).
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* statuses."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its OpenAI-style finish reason.

        Returns:
            ``"stop"``, ``"length"`` or ``"abort"`` for finished statuses,
            and ``None`` for any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            # Ignored sequences are those whose prompt lengths are longer
            # than the model's length cap. Therefore, the stop reason is also
            # "length", as in the OpenAI API.
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None


class SequenceStage(enum.Enum):
    """Processing phase of a sequence.

    A sequence starts in PREFILL and moves to DECODE once all of its tokens
    have been computed.
    """
    # Explicit values match what enum.auto() assigns (starting at 1).
    PREFILL = 1
    DECODE = 2


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time of the most recent token event for this
            request (NOTE(review): not updated in this file; presumably the
            time the latest token was generated; confirm).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
            being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when
            this request was in the batch.
        model_execute_time: The time spent in the model execute function.
            This will include model forward, block/sync across workers,
            cpu-gpu sync time and sampling time.
        spec_token_acceptance_counts: Number of accepted speculative tokens
            at each position. The first token is from the target model and is
            always accepted. For example, [10, 8, 4, 2] for a request means
            there were 10 forward passes in total, with 8, 4 and 2 accepted
            tokens at the 1st, 2nd and 3rd speculation step.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
    spec_token_acceptance_counts: Optional[list[int]] = None


class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and applied with
    ``SequenceData.apply_delta``. Because the struct is ``array_like``, the
    field order below is part of the encoded form; do not reorder it.
    """
    # New output tokens to be appended to the existing SequenceData.
    new_output_token_ids: list[int]
    # Overwrites the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Prefer the :meth:`from_seqs` / :meth:`from_prompt_token_counts`
    constructors. Only ``_prompt_token_ids`` and ``_output_token_ids`` are
    meant to be passed as arguments.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Defaults to an empty
            array.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[list, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Prompt + output tokens as one list, kept in sync incrementally.
    _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: list[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    # Only read/written through get_first_step_flag / set_first_step_flag.
    _first_step_flag: bool = True

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: tuple[int, int]) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance by concatenating
        prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        :code:`(token_id, count)`.
        """
        if not token_counts:
            return SequenceData.from_seqs([])

        # `array.__iadd__` extends the first (freshly built) array in place,
        # so the concatenation avoids repeated full copies.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance from prompt and output
        token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        # Both buffers must use the module-wide token-id typecode.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        # Cache a hashable copy of the prompt. The public `prompt_token_ids`
        # setter raises, so the prompt is not replaced after construction.
        self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the prompt + output token list from the two arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        The array uses the ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", C signed long)
        typecode, whose item size is platform-dependent. Check ``itemsize``
        before reinterpreting the buffer as a fixed-width dtype such as
        torch.long. The internal array is returned, not a copy.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        The array uses the ``VLLM_TOKEN_ID_ARRAY_TYPE`` ("l", C signed long)
        typecode, whose item size is platform-dependent. Check ``itemsize``
        before reinterpreting the buffer as a fixed-width dtype such as
        torch.long. The internal array is returned, not a copy.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum.

        The token is recorded in the output array, the cached all-token list
        and the pending delta returned by :meth:`get_delta_and_reset`.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> list[int]:
        """Return prompt + output token ids (the cached list, not a copy)."""
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns ``(prompt_prefix, output_prefix)``. ``output_prefix`` is None
        when the first ``num_tokens`` tokens lie entirely within the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Return the changes since the last call and start a new delta.

        The pending-token list is rebound to a fresh list rather than
        cleared, so the returned delta keeps its own token list.
        """
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by :meth:`get_delta_and_reset`.

        Scalar state is overwritten and the new tokens are appended. The
        pending delta of this instance (`_new_appended_tokens`) is not
        touched.
        """
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def get_first_step_flag(self):
        """Return the first-step flag (True until set otherwise).

        NOTE(review): the flag is not read within this file; its meaning is
        inferred from its name only. Confirm with callers.
        """
        return self._first_step_flag

    def set_first_step_flag(self, flag: bool):
        """Set the first-step flag returned by :meth:`get_first_step_flag`."""
        self._first_step_flag = flag

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the :data:`DecoderOnlyInputs`
    (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
    instance passed in through the :code:`inputs` constructor argument.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: SingletonInputs,
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = SingletonInputsAdapter(inputs)
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[list[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens (ceiling division)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        return self.inputs.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        return self.inputs.prompt_token_ids

    @property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self.inputs.prompt_embeds

    @property
    def token_type_ids(self) -> list[int]:
        return self.inputs.token_type_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.multi_modal_data

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        return self.inputs.multi_modal_placeholders

    @property
    def mm_processor_kwargs(self) -> dict[str, Any]:
        return self.inputs.mm_processor_kwargs

    @property
    def lora_int_id(self) -> int:
        """LoRA id of the request, or 0 when no LoRA is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned.

        While the sequence is unfinished and ``buffer_length`` is non-zero,
        the last ``buffer_length`` characters are held back.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned.

        In delta mode, a single new token is returned as a bare ``int``
        rather than a one-element sequence.
        """
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._cached_all_token_ids[-1]

        if num_new_tokens == 0:
            return []

        return self.data._cached_all_token_ids[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at logical block ``logical_idx``,
        combined with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def extra_hash(self) -> Optional[int]:
        """
        This function computes an extra hash for a sequence, specifically
        designed for prefix caching mode. The final sequence hash is determined
        by applying token_ids from the sequence's blocks.

        Returns None when neither a prompt adapter nor a LoRA is attached.
        """
        if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
            return None

        # NOTE: If there are additional factors influencing the block aside from
        # token_ids, include them as input parameters to the hash.
        return hash((self.prompt_adapter_id, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..``logical_idx`` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int,
                        logprobs: dict[int, Logprob]) -> None:
        """Append a sampled token; ``logprobs`` must contain ``token_id``."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> list[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence with a new id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def get_num_computed_tokens(self) -> int:
        return self.data.get_num_computed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")


class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group"""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of multi-step decoding steps still to run."""
        return self.num_steps - self.current_step


class SequenceGroup:
    """A collection of sequences that were generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The sequences of the group. The first entry is used as the
            representative sequence for prompt-level queries.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
        priority: User-defined priority of the request.
        draft_size: The number of speculative tokens plus one from the target
                    model. It sizes the per-position speculative acceptance
                    counters in the request metrics. It equals the max
                    number of tokens a step can generate for single-draft
                    speculative decoding, but is larger than that for
                    multi-draft SD (currently not supported).
    """

    def __init__(self,
                 request_id: str,
                 seqs: list[Sequence],
                 arrival_time: float,
                 sampling_params: Optional[SamplingParams] = None,
                 lora_request: Optional[LoRARequest] = None,
                 pooling_params: Optional[PoolingParams] = None,
                 pooled_data: Optional[torch.Tensor] = None,
                 encoder_seq: Optional[Sequence] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0,
                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        # Cached flag so the hot-path helpers below can skip scanning seqs.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {member.seq_id: member for member in seqs}

        self.sampling_params = sampling_params
        acceptance_counts = [0] * draft_size
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
            spec_token_acceptance_counts=acceptance_counts)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

    # ----------------------------------------------------------------------
    # Prompt / encoder accessors
    # ----------------------------------------------------------------------
    @property
    def prompt(self) -> Optional[str]:
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> list[int]:
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # At most one encoder sequence exists; when present, its prompt is
        # distinct from the decoder prompt.
        encoder = self.encoder_seq
        if encoder is None:
            return None
        return encoder.prompt

    @property
    def encoder_prompt_token_ids(self) -> Optional[list[int]]:
        # At most one encoder sequence exists; when present, its prompt token
        # ids are distinct from the decoder's.
        encoder = self.encoder_seq
        if encoder is None:
            return None
        return encoder.prompt_token_ids

    @property
    def token_type_ids(self) -> Optional[list[int]]:
        return self.first_seq.token_type_ids

    # ----------------------------------------------------------------------
    # Multi-modal accessors
    # ----------------------------------------------------------------------
    def _mm_source_seq(self) -> Optional[Sequence]:
        """Pick the sequence whose multi-modal fields should be reported.

        The decoder-side (first) sequence wins when it carries multi-modal
        data; otherwise the encoder sequence is used, which may be None.
        """
        if self.first_seq.multi_modal_data:
            return self.first_seq
        return self.encoder_seq

    @property
    def multi_modal_data(self) -> MultiModalDataDict:
        source = self._mm_source_seq()
        return {} if source is None else source.multi_modal_data

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        source = self._mm_source_seq()
        return {} if source is None else source.multi_modal_placeholders

    @property
    def mm_processor_kwargs(self) -> dict[str, Any]:
        source = self._mm_source_seq()
        return {} if source is None else source.mm_processor_kwargs

    # ----------------------------------------------------------------------
    # Adapter ids
    # ----------------------------------------------------------------------
    @property
    def lora_int_id(self) -> int:
        request = self.lora_request
        return request.lora_int_id if request else 0

    @property
    def prompt_adapter_id(self) -> int:
        request = self.prompt_adapter_request
        return request.prompt_adapter_id if request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        request = self.prompt_adapter_request
        return request.prompt_adapter_num_virtual_tokens if request else 0

    # ----------------------------------------------------------------------
    # Multi-step scheduling
    # ----------------------------------------------------------------------
    def init_multi_step(self, num_steps: int) -> None:
        state = self.state
        state.num_steps = num_steps
        state.current_step = 0

    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        if not is_multi_step:
            self.init_multi_step(num_steps=num_scheduler_steps)
            return

        # Multi-step case. The asserts encode the invariants the current
        # scheduler is expected to uphold.
        is_prefill = self.is_prefill()
        if is_prefill and enable_chunking:
            assert num_lookahead_slots == num_scheduler_steps
            self.init_multi_step(num_steps=num_lookahead_slots)
            return

        # A (non-chunked) prefill must not request lookahead slots.
        assert num_lookahead_slots == 0 or not is_prefill
        # For a decode, lookahead slots plus one must equal scheduler steps.
        assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
        self.init_multi_step(num_steps=num_lookahead_slots + 1)

    # ----------------------------------------------------------------------
    # Request-level timing metrics
    # ----------------------------------------------------------------------
    def set_last_token_time(self, now: float) -> None:
        """Record `now` as the last token time and store the latency since
        the previously recorded last token time."""
        # Only meaningful once decoding has started.
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
        metrics = self.metrics
        self.last_token_latency = now - metrics.last_token_time
        metrics.last_token_time = now

    def get_last_token_latency(self) -> float:
        """Return the latency computed by the last `set_last_token_time`."""
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
        return self.last_token_latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Record the first token time, once, when the first sequence has
        exactly one output token."""
        # When a sequence group is swapped out and recomputed, the time
        # between iterations is counted in TPOT rather than re-computing
        # TTFT, since from the user's point of view there is simply a long
        # generation delay.
        if self.metrics.first_token_time is not None:
            return
        if self.first_seq.get_output_len() == 1:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Record the first scheduled time and the time spent in queue,
        only on the first call."""
        metrics = self.metrics
        if metrics.first_scheduled_time is not None:
            return
        metrics.first_scheduled_time = time
        metrics.time_in_queue = time - metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Record the finished time for request-level timings."""
        self.metrics.finished_time = time

    # ----------------------------------------------------------------------
    # Sequence queries
    # ----------------------------------------------------------------------
    def get_max_num_running_seqs(self) -> int:
        """Upper bound on sequences running in parallel for the rest of the
        request's lifetime: the number of unfinished sequences."""
        if not self.is_single_seq:
            return self.num_seqs() - self.num_finished_seqs()
        return 0 if self.first_seq.is_finished() else 1

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> list[Sequence]:
        if status is None:
            return self.seqs
        if not self.is_single_seq:
            return [member for member in self.seqs if member.status == status]
        return self.seqs if self.first_seq.status == status else []

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_finished_seqs(self) -> list[Sequence]:
        if not self.is_single_seq:
            return [member for member in self.seqs if member.is_finished()]
        return self.seqs if self.first_seq.is_finished() else []

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Advance the computed-token count of every unfinished sequence."""
        for member in self.seqs:
            if member.is_finished():
                continue
            member.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens across unfinished sequences."""
        return sum(member.data.get_num_uncomputed_tokens()
                   for member in self.seqs if not member.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Fast paths: no filtering needed, or only one sequence to check.
        if status is None:
            return len(self.seqs)
        if not self.is_single_seq:
            return len(self.get_seqs(status))
        matches = self.seqs[0].status == status
        return 1 if matches else 0

    def num_finished_seqs(self) -> int:
        if not self.is_single_seq:
            return len(self.get_finished_seqs())
        return 1 if self.seqs[0].is_finished() else 0

    def is_finished(self) -> bool:
        if not self.is_single_seq:
            return all(member.is_finished() for member in self.seqs)
        return self.first_seq.is_finished()

    def is_prefill(self) -> bool:
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        fields = (
            f"request_id={self.request_id}",
            f"sampling_params={self.sampling_params}",
            f"num_seqs={len(self.seqs)}",
        )
        return "SequenceGroup(" + ", ".join(fields) + ")"
929
930


931
932
933
934
935
936
937
938
939
940
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update to a previously sent `SequenceGroupMetadata`.

    Once the full metadata for a group has been sent, the vLLM scheduler
    only sends these deltas, which keeps the per-step payload small.
    """
    # NOTE: field order is significant because the struct is array_like.
    # Per-sequence data deltas, keyed by sequence id.
    seq_data_delta: dict[int, SequenceDataDelta]
    request_id: str
    # Block tables, keyed by sequence id.
    block_tables: dict[int, list[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes part of the prompt.
        pooling_params: Pooling parameters (for pooling models).
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        token_type_ids: Optional token type ids.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Placeholders for the multi modal data.
        mm_processor_kwargs: Multimodal input processor / mapper overrides.
        encoder_seq_data: Optional sequence data for the encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per
            sequence). If left as None it is filled in by `__post_init__`:
            the full length of the first sequence for prompts, 1 otherwise.
        num_speculative_tokens: Number of speculative tokens adopted by
            this request (see the field comment below).
    """
    # NOTE: field order is significant because the struct is array_like.
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)
    token_type_ids: Optional[list[int]] = None
    # Holds "MultiModalDataDict" values. Typed as Any because msgspec does
    # not allow a union of two different dict types.
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    mm_processor_kwargs: Optional[dict[str, Any]] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[list[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reason.
    # TODO: We should maintain this state outside of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Derive a default chunk size only when one was not supplied.
        if self.seq_data is None or self.token_chunk_size is not None:
            return
        if self.is_prompt:
            first_seq_data = next(iter(self.seq_data.values()))
            self.token_chunk_size = first_seq_data.get_len()
        else:
            self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        request = self.lora_request
        return request.lora_int_id if request else 0

    @property
    def prompt_adapter_id(self) -> int:
        request = self.prompt_adapter_request
        return request.prompt_adapter_id if request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        request = self.prompt_adapter_request
        return request.prompt_adapter_num_virtual_tokens if request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is only true when token_chunk_size matches the number of
        # uncomputed tokens of the sequence, i.e. the prompt finishes
        # processing in a single `execute_model` step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # Cheap way to get the seq_id when the group is known to hold
        # exactly one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        delta = sequence_group_metadata_delta
        for seq_id, seq_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        assert self.request_id == delta.request_id
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt

    def finish_step(self) -> None:
        """Advance the multi-step counter by one step."""
        state = self.state
        assert state is not None
        assert state.current_step < state.num_steps, \
            f"current step {state.current_step}, num_steps {state.num_steps}" # noqa
        state.current_step += 1

1070

1071
1072
1073
1074
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Model output produced for a single sequence in one step.

    Args:
        parent_seq_id: ID of the parent sequence (used when forking in beam
            search).
        output_token: ID of the generated token.
        logprobs: Logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: dict[int, Logprob]

    def __repr__(self) -> str:
        fields = (
            f"parent_seq_id={self.parent_seq_id}",
            f"output_token={self.output_token}",
            f"logprobs={self.logprobs}",
        )
        return "SequenceOutput(" + ", ".join(fields) + ")"

    def __eq__(self, other: object) -> bool:
        # Comparing against any other type is treated as a programming error.
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and other.logprobs == self.logprobs)
1100
1101


1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
class SequenceGroupOutput(ABC):
    """Abstract base for model outputs that belong to a sequence group.

    Concrete outputs must provide a readable `__repr__` and an `__eq__`.
    """

    @abstractmethod
    def __repr__(self) -> str:
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        ...


1114
1115
1116
1117
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Model output for a completion (generation) sequence group."""
    __metaclass__ = SequenceGroupOutput
    # One sampled output per sequence in the group.
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
    step_index: Optional[int] = 0

    def __repr__(self) -> str:
        fields = (
            f"samples={self.samples}",
            f"prompt_logprobs={self.prompt_logprobs}",
        )
        return "CompletionSequenceGroupOutput(" + ", ".join(fields) + ")"

    def __eq__(self, other: object) -> bool:
        # Comparing against any other type is treated as a programming error.
        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        if self.samples != other.samples:
            return False
        return self.prompt_logprobs == other.prompt_logprobs

1135

1136
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    __metaclass__ = SequenceGroupOutput
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def __repr__(self) -> str:
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare pooled data; comparing against another type is treated
        as a programming error.

        Raises:
            NotImplementedError: if `other` is not a
                PoolingSequenceGroupOutput.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        # `==` on tensors is element-wise and yields a tensor, which cannot
        # be used as a truth value when it has more than one element.
        # torch.equal reduces the comparison to a single bool.
        if isinstance(self.data, torch.Tensor) and isinstance(
                other.data, torch.Tensor):
            return torch.equal(self.data, other.data)
        return self.data == other.data
1154
1155


1156
1157
1158
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: dict[str, torch.Tensor]

    def __init__(self, tensors: dict[str, torch.Tensor]):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        A `str` key returns the named tensor. A `slice` key returns a new
        IntermediateTensors in which each tensor is indexed with that slice
        (i.e. along its first dimension).

        Raises:
            KeyError: if a `str` key is not present.
            TypeError: if `key` is neither a `str` nor a `slice`.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and silently returned None.
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """Equal when `other` is the same class, holds the same keys, and
        every tensor has identical shape and values."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal returns a single bool, unlike element-wise `==`.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


1195
1196
1197
1198
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Result of a pooling operation in a pooling model: one output per
    sequence group, accessible by index."""
    outputs: list[PoolingSequenceGroupOutput]

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        items = self.outputs
        return items[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        items = self.outputs
        items[idx] = value

    def __len__(self):
        items = self.outputs
        return len(items)

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


1216
def get_all_seq_ids(
        seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
    """Flatten the sequence ids of every group into a single list.

    Groups are visited in list order, and within a group ids follow the
    iteration order of its `seq_data` dict.
    """
    collected: list[int] = []
    for group in seq_group_metadata_list:
        collected.extend(group.seq_data)
    return collected


1224
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
    """Collect all sequence ids and group them by request id.

    Returns:
        A pair ``(seq_ids, mapping)``. ``seq_ids`` is the flattened list of
        sequence ids in group order. ``mapping`` is a ``defaultdict(set)``
        from request id to that request's sequence ids. Groups without any
        sequence data add no entry to ``mapping``.
    """
    seq_ids: list[int] = []
    request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        group_ids = list(group.seq_data)
        if not group_ids:
            continue
        seq_ids.extend(group_ids)
        request_id_seq_ids_mapping[group.request_id].update(group_ids)
    return seq_ids, request_id_seq_ids_mapping


1239
1240
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Rows of previously tracked sequences that also appear in
        `seq_group_metadata_list` are dropped and replaced by the new rows.
        The resulting order is: surviving old rows (original order), then
        the new rows (batch order).
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
        # Set lookup keeps this linear; the previous list membership test
        # plus list.index() made it quadratic in the batch size.
        new_id_set = set(new_seq_ids)
        keep = [
            i for i, seq_id in enumerate(self._seq_ids)
            if seq_id not in new_id_set
        ]
        kept_seq_ids = [self._seq_ids[i] for i in keep]

        self.hidden_states = torch.cat(
            [self.hidden_states[keep], hidden_states])

        if self.second_last_token_hidden_states is not None:
            # When the caller provides no second-last states, pad with zeros
            # so this tensor keeps the same batch size as hidden_states.
            incoming_second_last = (torch.zeros_like(hidden_states)
                                    if second_last_token_hidden_states is None
                                    else second_last_token_hidden_states)
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states[keep],
                incoming_second_last
            ])

        self._seq_ids = kept_seq_ids + new_seq_ids

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Raises:
            ValueError: if a requested sequence id is not currently tracked.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For every sequence in `seq_with_bonus_token_in_last_step`, its
        second-last-token row is inserted just before its regular row.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        num_seqs = len(self._seq_ids)
        index = []
        # enumerate gives the position directly; list.index() inside the
        # loop made this quadratic.
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                # Row i of second_last_token_hidden_states sits at
                # i + num_seqs in the concatenation below.
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]

1342

1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
class Logits(msgspec.Struct, array_like=True,
             omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.

    Used in speculative decoding to pass lm_head logits from the target
    model to the proposer model in the subsequent step.

    ``seq_ids`` are the sequence ids of each entry of the batch dimension of
    the logits tensor: row ``i`` of ``logits`` belongs to ``seq_ids[i]``.
    """
    # Logits produced by the scorer (target) model, batch dimension first
    # (``update`` concatenates and ``prune`` selects along dim 0).
    # NOTE(review): whether prefill steps carry one row per token or per
    # sequence is decided by the producer of this struct — confirm there.
    logits: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None

    # Sequence ids aligned with the batch dimension of ``logits``.
    _seq_ids: list[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        # When metadata is given, derive the row -> sequence-id mapping from
        # it and check it has one entry per logits row.
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.logits)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> list[int]:
        """Sequence ids aligned with the batch dimension of ``logits``."""
        return self._seq_ids

    def update(self,
               logits: torch.Tensor,
               seq_group_metadata_list: list[SequenceGroupMetadata]):
        """Append logits from a target model invocation. Only used for
        decode steps.

        Args:
            logits: New rows, concatenated after the existing ``logits``.
            seq_group_metadata_list: Metadata for the new rows; its sequence
                ids are appended to ``seq_ids``.
        """
        assert len(seq_group_metadata_list) == len(logits)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows of ``logits`` are filtered and reordered to match the sequence
        ids of ``seq_group_metadata_list``.

        Raises:
            ValueError: if the batch references an untracked sequence id.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list, which might cause problems where a
        # sequence may be "paused" then "resumed" later. This should only
        # prune sequences which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.logits = self.logits[index]
            self._seq_ids = seq_ids


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    NOTE: with ``array_like=True`` msgspec encodes this struct positionally,
    so keep the field order stable and append new fields at the end.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: list[tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: list[tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # Optional logits from prior step.
    previous_logits: Optional[Logits] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # The step index for spec model input.
    spec_step_idx: Optional[int] = None
    # Finished request ids since last step.
    finished_requests_ids: list[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None
    # Optional tree attention mask from draft model.
    tree_attn_masks: Optional[torch.Tensor] = None
    # Optional tree position ids from draft model.
    tree_position_ids: Optional[torch.Tensor] = None
    # Optional KV-cache slot mapping, generated by the draft model, for
    # entries that are still pending to be moved.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first sequence group is at multi-step index 0."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        assert first_seq_group.state is not None
        return first_seq_group.state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first sequence group has exactly one step left."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        assert first_seq_group.state is not None
        return first_seq_group.state.remaining_steps == 1

    @property
    def current_step(self) -> int:
        """Current multi-step index, read from the first sequence group."""
        # TODO(will) make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state.current_step

    def clone(
        self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        Block lists are shallow-copied and ``last_sampled_token_ids`` is
        cloned. Every other field, including ``spec_step_idx`` (previously
        dropped), is carried over by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            previous_logits=self.previous_logits,
            num_steps=self.num_steps,
            spec_step_idx=self.spec_step_idx,
            # NOTE: shared with the original request, not copied.
            finished_requests_ids=self.finished_requests_ids,
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback,
            tree_attn_masks=self.tree_attn_masks,
            tree_position_ids=self.tree_position_ids,
            kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a user request that was split into sub-requests.

    Every sub-request lives in the engine as its own ``SequenceGroup``.
    Subclasses define how the split happens (``add_request``) and when the
    user-facing group is rebuilt (``maybe_assemble_group``).
    """
    group_id: str  # the original request id before splitting

    assembled_seq_group: Optional[SequenceGroup] = None

    # sub-request id -> its unique position inside this group
    seq_id_to_index: dict[str, int] = field(default_factory=dict)

    # sub-request id -> sequence group that has not finished yet
    to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)

    # sub-request id -> sequence group that has already finished
    finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)

    # whether outputs are streamed incrementally (set by subclasses)
    streaming: bool = False

    # whether the assembled output has already been handed out
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Split the request ``request_id`` with ``params`` into several
        sub-requests and add them to ``engine``. Must be overridden.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that the sub-request ``seq`` has finished.

        Raises:
            KeyError: if ``seq.request_id`` is not among the pending ones.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group to report or re-add to the engine,
        or None when nothing should be produced yet. Must be overridden.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Runs a request with ``n`` samples as ``n`` single-sample requests and
    reassembles their sequences into one ``SequenceGroup``."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Add ``params.n`` single-sample sub-requests to ``engine``.

        Sub-request ``i`` gets the id ``f"{request_id}_parallel_sample_{i}"``
        and a deep copy of ``params`` with ``n = 1``. When a seed is set, it
        uses ``seed + i`` so that samples differ but stay reproducible. Each
        sub-request id is mapped to this group in
        ``engine.seq_id_to_seq_group``.
        """
        original_params = params
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            params = copy.deepcopy(original_params)
            params.n = 1
            if params.seed is not None:
                params.seed += i
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[request_id_i] = group
            group.to_be_finished[request_id_i] = seq_group
            seqs.append(seq_group.seqs[0])

        # for parallel sampling, the `assembled_seq_group` is always
        # available, since we have all the sequences ready, and they
        # will not change. Request-level attributes are taken from the last
        # sub-request; they are identical across sub-requests.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        # Read from the caller's params rather than the last per-sample copy
        # (same value, since the copies only differ in ``n`` and ``seed``).
        group.streaming = (
            original_params.output_kind == RequestOutputKind.DELTA)
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when ``seq_group``'s output should be
        reported, otherwise None.

        Streaming: the assembled group is returned only for the first
        still-pending sub-request, so each step yields a single output.
        Non-streaming: it is returned once, when the last pending
        sub-request finishes. If ``_real_n`` is set, the sequences are
        trimmed to the best ``n`` by cumulative logprob.
        """
        if self.streaming:
            # ``None`` default: once every sub-request has finished there is
            # no "first remaining" one, and nothing should be emitted
            # (previously this raised StopIteration).
            first_remaining_id = next(iter(self.to_be_finished), None)
            if seq_group.request_id == first_remaining_id:
                return self.assembled_seq_group
            return None

        is_last_to_finish = (len(self.to_be_finished) == 1
                             and seq_group.request_id in self.to_be_finished
                             and seq_group.is_finished())
        if not is_last_to_finish:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            # The final output was already handed out.
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            self.assembled_seq_group.seqs = sorted(
                self.assembled_seq_group.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True)[:n]
        return self.assembled_seq_group