sequence.py 54.6 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass, field
8
from functools import cached_property, reduce
9
10
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
                    Mapping, Optional)
11
from typing import Sequence as GenericSequence
12
from typing import Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
import msgspec
15
import torch
16
from typing_extensions import assert_never
17

18
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
20
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import RequestOutputKind, SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
if TYPE_CHECKING:
25
    from vllm.inputs import SingletonInputs
26

27
# `array` typecode used for every token-id buffer in this module
# ("l" is a C signed long).
VLLM_TOKEN_ID_ARRAY_TYPE = "l"

# Sentinel token id value; it is not referenced in this part of the file,
# presumably it marks "no valid token" for callers elsewhere.
VLLM_INVALID_TOKEN_ID = -1

31

32
33
34
35
36
def array_full(token_id: int, count: int):
    """Build an ``array`` containing ``token_id`` repeated ``count`` times.

    This mirrors :func:`numpy.full` for :class:`array`; the elements use the
    ``VLLM_TOKEN_ID_ARRAY_TYPE`` typecode.
    """
    one_element = array(VLLM_TOKEN_ID_ARRAY_TYPE, (token_id, ))
    return one_element * count


37
38
39
# We use a dataclass (rather than a msgspec.Struct) for now because this
# type is used in the OpenAI server output, and msgspec structs are not
# serializable there.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# {token_id -> logprob} dicts for the prompt of a sequence group. Entries
# may be None; presumably one entry per prompt position, with None where no
# logprob is available (e.g. the first prompt token) — confirm in sampler.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# One {token_id -> logprob} dict per generated (output) token of a sequence.
SampleLogprobs = List[Dict[int, Logprob]]
59

Woosuk Kwon's avatar
Woosuk Kwon committed
60

61
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    The values are ordered so that every status after ``SWAPPED`` is a
    finished status; :meth:`is_finished` relies on this ordering.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is one of the ``FINISHED_*`` values."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a finished status to its OpenAI-style finish reason.

        Returns:
            ``"stop"``, ``"length"`` or ``"abort"`` for finished statuses,
            and None for any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
93

94

95
96
97
98
99
class SequenceStage(enum.Enum):
    """Processing phase of a sequence.

    ``PREFILL`` is the initial stage; ``SequenceData`` switches a sequence to
    ``DECODE`` once all of its tokens have been computed.
    """
    PREFILL = 1
    DECODE = 2


100
101
102
103
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The reference time for the next inter-token latency
            measurement (``SequenceGroup.get_last_latency`` advances it).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
        scheduler_time: The time spent in the scheduler when this request was
                        being considered by the scheduler.
        model_forward_time: The time spent in the model forward pass when this
                            request was in the batch.
        model_execute_time: The time spent in the model execute function. This
                            will include model forward, block/sync across
                            workers, cpu-gpu sync time and sampling time.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
127
128


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and consumed by
    ``SequenceData.apply_delta``.
    """
    # NOTE: with ``array_like=True`` the fields are encoded positionally, so
    # the declaration order below is part of the serialized format.

    # New output tokens to be appended to the existing SequenceData.
    new_output_token_ids: List[int]
    # Overwriting existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwriting existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwriting existing `stage`.
    new_stage: SequenceStage


class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Instances are normally built with :meth:`from_seqs` or
    :meth:`from_prompt_token_counts`, which convert token ids into ``array``
    objects of typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` (checked in
    ``__post_init__``).

    Args:
        _prompt_token_ids: The token IDs of the prompt.
        _output_token_ids: The token IDs of the output. Defaults to an empty
            array.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    # Immutable copy of the prompt ids, filled in by `__post_init__`.
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    # The number of tokens with prefix cache hit.
    _num_cached_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Flat list of prompt + output token ids. It is kept in sync
    # incrementally so `get_token_ids()` does not concatenate on every call.
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # (or `reset_state_for_recompute`) is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    # It is used to compute mrope_position_ids.
    _mrope_position_delta: Optional[int] = None

    @staticmethod
    def from_prompt_token_counts(
            *token_counts: Tuple[int, int]) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance by concatenating
        prompt token sequences.

        Each tuple represents one token sequence, expressed in the form
        :code:`(token_id, count)`.
        """
        if not token_counts:
            return SequenceData.from_seqs([])

        # In-place concatenation of the per-run arrays; the first array
        # produced by the generator is reused as the accumulator.
        prompt_token_ids_arr = reduce(
            array.__iadd__,
            (array_full(token_id, count) for token_id, count in token_counts),
        )

        return SequenceData(prompt_token_ids_arr)

    @staticmethod
    def from_seqs(
        prompt_token_ids: GenericSequence[int],
        output_token_ids: Optional[GenericSequence[int]] = None,
    ) -> "SequenceData":
        """
        Construct a :class:`SequenceData` instance from prompt and output
        token sequences.
        """
        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     prompt_token_ids)

        if output_token_ids is None:
            return SequenceData(prompt_token_ids_arr)

        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                     output_token_ids)

        return SequenceData(prompt_token_ids_arr,
                            _output_token_ids=output_token_ids_arr)

    def __post_init__(self) -> None:
        # Both buffers must use the module-wide token-id typecode.
        assert self._prompt_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        assert self._output_token_ids.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
        self._prompt_token_ids_tuple = tuple(self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild `_cached_all_token_ids` from the prompt and output arrays."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids = list(self._prompt_token_ids +
                                          self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids in array type.

        Note that the array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE``
        ("l", a C ``long``), whose item size is platform dependent (8 bytes
        on typical 64-bit Linux, 4 bytes on Windows). Check ``itemsize``
        before reinterpreting the buffer as a fixed-width dtype such as
        ``torch.long``.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self,
                         new_output_token_ids: GenericSequence[int]) -> None:
        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids in array type.

        Note that the array uses typecode ``VLLM_TOKEN_ID_ARRAY_TYPE``
        ("l", a C ``long``), whose item size is platform dependent (8 bytes
        on typical 64-bit Linux, 4 bytes on Windows). Check ``itemsize``
        before reinterpreting the buffer as a fixed-width dtype such as
        ``torch.long``.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    @property
    def mrope_position_delta(self) -> Optional[int]:
        return self._mrope_position_delta

    @mrope_position_delta.setter
    def mrope_position_delta(self, new_mrope_position_delta):
        self._mrope_position_delta = new_mrope_position_delta

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one generated token and accumulate its logprob."""
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns a ``(prompt_part, output_part)`` pair; ``output_part`` is
        None when the prefix lies entirely within the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def get_num_cached_tokens(self) -> int:
        """Return the number of tokens with prefix cache hit."""
        return self._num_cached_tokens

    def update_num_cached_tokens(self, num_cached_tokens: int):
        """Update the number of tokens with prefix cache hit."""
        self._num_cached_tokens = num_cached_tokens

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Snapshot the changes since the last call and clear the token delta.
        """
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset` to this instance."""
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
388
389


Woosuk Kwon's avatar
Woosuk Kwon committed
390
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the :data:`SingletonInputs` instance
    passed in through the :code:`inputs` constructor argument; the prompt
    accessors below only handle inputs whose ``type`` is ``"token"``.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "SingletonInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData.from_seqs(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # These are used to keep track of delta outputs
        self._last_output_token_ids_offset: int = 0
        self._last_output_text_offset: int = 0

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed for all tokens (rounded up)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @cached_property
    def prompt(self) -> Optional[str]:
        """Prompt text from the inputs, or None if absent."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("prompt")

        assert_never(inputs)

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        """Prompt token ids from the inputs (empty list if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("prompt_token_ids", [])

        assert_never(inputs)

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Always None for token inputs."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return None

        assert_never(inputs)

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """Multi-modal data from the inputs (empty dict if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_data", {})

        assert_never(inputs)

    @cached_property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """Multi-modal processor kwargs from the inputs (empty if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("mm_processor_kwargs", {})

        assert_never(inputs)

    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        """Multi-modal placeholders from the inputs (empty if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_placeholders", {})

        assert_never(inputs)

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int,
                                  delta: bool) -> str:
        """Return the output text to hand back to the caller.

        If delta is True, only new text since the last call to this method
        is returned. While the sequence is unfinished and ``buffer_length``
        is non-zero, the last ``buffer_length`` characters are held back.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        if not delta:
            return self.output_text[:-buffer_length] if truncate else (
                self.output_text)
        length = len(self.output_text)
        if truncate:
            length -= buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """If delta is True, only new tokens since the last call to
        this method are returned. A single new token is returned as a bare
        int rather than a sequence."""
        if not delta:
            return self.get_output_token_ids()

        output_len = self.get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data.get_token_ids()[-1]

        if num_new_tokens <= 0:
            # Nothing new. A negative count can arise e.g. if the output was
            # replaced by a shorter one via `data.output_token_ids`; slicing
            # with `-num_new_tokens` would then return leading tokens.
            return []

        return self.data.get_token_ids()[-num_new_tokens:]

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..logical_idx inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append a generated token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
623

Woosuk Kwon's avatar
Woosuk Kwon committed
624

625
626
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to a specific sequence group."""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of multi-step decoding steps not yet taken."""
        return self.num_steps - self.current_step


Woosuk Kwon's avatar
Woosuk Kwon committed
638
class SequenceGroup:
639
640
641
642
643
644
645
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
646
        lora_request: LoRA request.
647
648
649
650
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
651
652
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
653
        trace_headers: OpenTelemetry trace headers.
654
        prompt_adapter_request: Prompt Adapter request.
655
        priority: User-defined priority of the request.
656
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
657
658
659

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Create a sequence group; arguments are described on the class.

        ``seqs`` must be non-empty: its first element is kept as
        ``first_seq`` and is used for prompt-level properties.
        """
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        # Both arrival and last-token times start at the arrival time; the
        # remaining timestamps are filled in as the request progresses.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self.priority = priority

        self.cached_request_output = None

697
    @property
698
    def prompt(self) -> Optional[str]:
699
        return self.first_seq.prompt
700
701
702

    @property
    def prompt_token_ids(self) -> List[int]:
703
        return self.first_seq.prompt_token_ids
704

705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

721
    @property
722
    def multi_modal_data(self) -> MultiModalDataDict:
723
        return self.first_seq.multi_modal_data
Woosuk Kwon's avatar
Woosuk Kwon committed
724

725
726
727
728
    @property
    def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
        return self.first_seq.multi_modal_placeholders

729
730
    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
731
        return self.first_seq.mm_processor_kwargs
732

733
734
735
736
    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

737
738
739
740
741
742
743
744
745
746
    @property
    def prompt_adapter_id(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
                         if self.prompt_adapter_request else 0

747
748
    def init_multi_step(self, num_steps: int) -> None:
        self.state.num_steps = num_steps
749
750
        self.state.current_step = 0

751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
    def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
                                             num_scheduler_steps: int,
                                             is_multi_step: bool,
                                             enable_chunking: bool) -> None:
        """Initialize the multi-step state from the scheduler settings.

        Without multi-step scheduling the group runs ``num_scheduler_steps``
        steps. With it, a chunked prefill runs ``num_lookahead_slots`` steps
        and every other case runs ``num_lookahead_slots + 1`` steps. The
        asserts encode what the current scheduler is expected to pass in.
        """
        if is_multi_step:
            prefill = self.is_prefill()
            if prefill and enable_chunking:
                assert num_lookahead_slots == num_scheduler_steps
                steps = num_lookahead_slots
            else:
                # A prefill must not come with lookahead slots.
                assert num_lookahead_slots == 0 or not prefill
                # A decode runs one step per lookahead slot plus one, which
                # must agree with the scheduler's step count.
                assert num_lookahead_slots + 1 == num_scheduler_steps or prefill
                steps = num_lookahead_slots + 1
        else:
            steps = num_scheduler_steps
        self.init_multi_step(num_steps=steps)

    def get_last_latency(self, now: float) -> float:
        """Return the time elapsed since the last token and record ``now``.

        Used for request-level inter-token timings. The latency is
        ``now - metrics.last_token_time``. Afterwards
        ``metrics.last_token_time`` is advanced to ``now``.

        Args:
            now: The current timestamp.

        Returns:
            The latency since the previously recorded token time.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # Token latency is only meaningful once the group has left prefill.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Record ``time`` as the request's first-token time, at most once.

        The time is only stored while no first-token time exists yet and the
        first sequence's output length is exactly 1.
        """
        # When a sequence group is swapped out and recomputed, the gap between
        # iterations is counted towards TPOT rather than resetting TTFT: from
        # the user's point of view it is simply a long generation delay.
        metrics = self.metrics
        if metrics.first_token_time is not None:
            return
        if self.first_seq.get_output_len() == 1:
            metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Record when the group was first scheduled, at most once.

        Also derives ``metrics.time_in_queue`` as the delay between the
        arrival time and that first scheduling.
        """
        metrics = self.metrics
        if metrics.first_scheduled_time is not None:
            return
        metrics.first_scheduled_time = time
        metrics.time_in_queue = time - metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Store ``time`` (which may be None) as the request's finish time."""
        metrics = self.metrics
        metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """Upper bound on sequences running in parallel for the rest of the
        request's lifetime: 1 until the first sequence finishes, then 0."""
        if self.first_seq.is_finished():
            return 0
        return 1

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the group's sequences, optionally filtered by ``status``.

        Without a status every sequence is returned. With one, the whole list
        is returned when the first sequence has that status, otherwise an
        empty list.
        """
        if status is None or self.first_seq.status == status:
            return self.seqs
        return []

    def is_encoder_decoder(self) -> bool:
        """True when this group carries an encoder sequence."""
        has_encoder = self.encoder_seq is not None
        return has_encoder

    def get_encoder_seq(self) -> Optional[Sequence]:
        """The encoder sequence, or None when the group has none."""
        encoder_seq = self.encoder_seq
        return encoder_seq

    def get_finished_seqs(self) -> List[Sequence]:
        """All sequences once the first one has finished, else an empty list."""
        if not self.first_seq.is_finished():
            return []
        return self.seqs

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Advance the first sequence's computed-token counter.

        A finished sequence is left untouched.
        """
        seq = self.first_seq
        if seq.is_finished():
            return
        seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Tokens of the first sequence not computed yet (0 once finished)."""
        seq = self.first_seq
        if seq.is_finished():
            return 0
        return seq.data.get_num_uncomputed_tokens()

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the group's sequences, optionally only those with ``status``."""
        # Fast paths: no filtering needed, or only one sequence to inspect.
        if status is None:
            return len(self.seqs)
        if not self.is_single_seq:
            return len(self.get_seqs(status))
        return 1 if self.seqs[0].status == status else 0

    def num_finished_seqs(self) -> int:
        """1 if the first sequence has finished, otherwise 0."""
        if self.first_seq.is_finished():
            return 1
        return 0

    def is_finished(self) -> bool:
        """Whether the group's first sequence has finished."""
        seq = self.first_seq
        return seq.is_finished()

    def is_prefill(self) -> bool:
        """Whether the group's first sequence is still in prefill."""
        seq = self.first_seq
        return seq.is_prefill()

    def __repr__(self) -> str:
        """Short summary: request id, sampling params and sequence count."""
        num_seqs = len(self.seqs)
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={num_seqs})")


class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.

    ``array_like=True`` makes msgspec encode the struct as an array, so the
    field order below is part of the serialized format and must not change.
    ``tag=True`` lets a decoder tell this struct apart from
    SequenceGroupMetadata inside a Union of the two.
    """
    # Per-sequence data deltas (seq id -> SequenceDataDelta).
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Request the delta belongs to.
    request_id: str
    # The block tables. (Seq id -> list of physical block numbers)
    block_tables: Dict[int, List[int]]
    # Whether the request is at prompt stage.
    is_prompt: bool
    # Whether sampling is required in this step.
    do_sample: bool = True
    # Number of tokens to process (per sequence) in this step.
    token_chunk_size: Optional[int] = None
    # Block numbers already computed, used in prefix caching.
    computed_block_nums: Optional[List[int]] = None
    # Internal state tied to the sequence group; defaults to a fresh
    # SequenceGroupState.
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())


class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    ``array_like=True`` makes msgspec encode the struct as an array, so the
    field order is part of the serialized format and must not change.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters, if any.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        multi_modal_placeholders: Multi modal placeholders.
        mm_processor_kwargs: Multimodal input processor / mapper overrides.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If left as None, ``__post_init__`` fills it in: the length of
            the first sequence for prompts, and 1 otherwise.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request (see the field comment below).
    """

    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Default the chunk size to "process everything this step": the
        # first sequence's full length for a prompt, one token otherwise.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter id, or 0 when there is no adapter request."""
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Prompt adapter virtual-token count, or 0 without an adapter."""
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
                        if self.prompt_adapter_request else 0

    # Multi-Step Chunked-Prefill property
    @property
    def is_single_step_prompt(self) -> bool:
        # do_sample is true, only when the token_chunk_size matches the
        # num_uncomputed_tokens of the sequence. This indicates that
        # the prompt will finish processing in a single `execute_model`
        # step.
        return self.is_prompt and self.do_sample

    def get_first_seq_id(self) -> int:
        # This is an efficient way of fetching the seq_id when
        # we know this SequenceGroup has only one sequence.
        return next(iter(self.seq_data))

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a scheduler delta.

        Args:
            sequence_group_metadata_delta: The delta for the same request;
                its ``request_id`` must match this metadata's.
        """
        delta = sequence_group_metadata_delta
        for seq_id, seq_data_delta in delta.seq_data_delta.items():
            self.seq_data[seq_id].apply_delta(seq_data_delta)
        assert self.request_id == delta.request_id
        self.block_tables = delta.block_tables
        self.token_chunk_size = delta.token_chunk_size
        self.do_sample = delta.do_sample
        self.is_prompt = delta.is_prompt
        # The delta carries computed_block_nums too. Without syncing it, the
        # receiver would keep the value from the first full metadata.
        self.computed_block_nums = delta.computed_block_nums

    def finish_step(self) -> None:
        """Advance the multi-step state by one step."""
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps, (
            f"current step {self.state.current_step}, "
            f"num_steps {self.state.num_steps}")
        self.state.current_step += 1


class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when parent id, output token and logprobs all match."""
        # For foreign types, defer to the other operand per the Python data
        # model instead of raising, so e.g. `output == None` is just False.
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs tied to one sequence group.

    Concrete outputs must implement both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""


class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: `__metaclass__` is a Python 2 idiom that Python 3 ignores; here it
    # is only a plain class attribute and does not make this class a
    # subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when both samples and prompt logprobs match."""
        # Defer to the other operand for foreign types instead of raising.
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    # NOTE: `__metaclass__` is a Python 2 idiom that Python 3 ignores; here it
    # is only a plain class attribute and does not make this class a
    # subclass of SequenceGroupOutput.
    __metaclass__ = SequenceGroupOutput
    embeddings: List[int]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Equal when the embeddings match."""
        # Defer to the other operand for foreign types instead of raising.
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return the tensor named ``key``. For a slice, return a new
        IntermediateTensors with every tensor sliced along its first dim.
        Other key types fall through and return None."""
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when ``other`` is the same class, holds the same tensor
        names, and each pair of tensors is ``torch.equal``."""
        # Compare contents explicitly: `isinstance(...) and self` is truthy
        # for any non-empty instance and never looks at the tensors.
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Supports indexing, item assignment and ``len()`` over ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Imported here rather than at module top to avoid a circular import.
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
        group_outputs = self.outputs
        return group_outputs[idx]

    def __setitem__(self, idx: int, value):
        group_outputs = self.outputs
        group_outputs[idx] = value

    def __len__(self):
        group_outputs = self.outputs
        return len(group_outputs)

    def __eq__(self, other: object):
        # Only instances of the same class can compare equal.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group, preserving batch order."""
    all_seq_ids: List[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        all_seq_ids.extend(seq_group_metadata.seq_data)
    return all_seq_ids


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence ids, both flat and grouped by request id.

    Args:
        seq_group_metadata_list: The sequence group metadata of a batch.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``. ``seq_ids`` lists
        every sequence id in batch order. The mapping sends each request id
        to the set of its sequence ids and is a ``defaultdict(set)``.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor."""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            # Map each known seq id to its first row in one pass, instead of
            # calling list.index() once per sequence (quadratic).
            row_of: Dict[int, int] = {}
            for row, seq_id in enumerate(self._seq_ids):
                row_of.setdefault(seq_id, row)
            index = [row_of[seq_id] for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`."""
        if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
            return

        # Rows [0, num_seqs) of the concatenation below are hidden_states and
        # rows [num_seqs, 2 * num_seqs) are second_last_token_hidden_states.
        # A sequence with a bonus token gets its second-last row first.
        num_seqs = len(self._seq_ids)
        index = []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + num_seqs)
            index.append(i)

        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]


class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch.

    ``array_like=True`` makes msgspec encode the struct as an array, so the
    field order is part of the serialized format and must not change."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async callback
    async_callback: Optional[Callable] = None

    def _first_group_state(self) -> SequenceGroupState:
        """State of the first sequence group, used as the batch's state."""
        # TODO(will) make this be able to handle batches with variable number
        # of steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """Whether the batch is at step 0 of a multi-step run."""
        return self._first_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """Whether exactly one step remains for the batch."""
        return self._first_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The batch's current multi-step index."""
        return self._first_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The list fields are shallow-copied and ``last_sampled_token_ids`` is
        cloned, so mutating those on the clone leaves this request intact.
        ``previous_hidden_states`` and ``async_callback`` are shared by
        reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the other lists so the clone does not alias it.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            async_callback=self.async_callback)


@dataclass
class SequenceGroupBase:
    """Bookkeeping for a request that was split into several sub-requests.

    Subclasses decide how to split the request (``add_request``) and when
    and how to stitch the sub-requests back together
    (``maybe_assemble_group``).
    """

    group_id: str  # the original request id before splitting

    assembled_seq_group: Optional[SequenceGroup] = None

    # sub-request id -> its unique index inside this group
    seq_id_to_index: Dict[str, int] = field(default_factory=dict)

    # sub-request id -> SequenceGroup, for sub-requests not finished yet
    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)

    # sub-request id -> SequenceGroup, for sub-requests that have finished
    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)

    streaming: bool = False

    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Split the request ``request_id`` with ``params`` into sub-requests
        and add them to ``engine``. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Move ``seq`` from the pending sub-requests to the finished ones.

        Raises KeyError if ``seq.request_id`` is not pending.
        """
        self.to_be_finished.pop(seq.request_id)
        self.finished_reqs[seq.request_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled sequence group when it should be emitted
        (for output or re-adding to the engine), else None. Must be
        implemented by subclasses.
        """
        raise NotImplementedError


class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Runs a request asking for ``n`` samples as ``n`` single-sample
    sub-requests and reassembles them into one sequence group."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Add ``params.n`` sub-requests with ``n=1`` to ``engine``.

        Sub-request ``i`` is added through ``engine._add_processed_request``
        under the id ``f"{request_id}_parallel_sample_{i}"``. Each id is
        mapped back to the shared group in ``engine.seq_id_to_seq_group``.
        """
        original_params = params
        single_params = copy.deepcopy(original_params)
        single_params.n = 1
        group = ParallelSampleSequenceGroup(request_id)
        seqs = []
        for index in range(original_params.n):
            sub_request_id = f"{request_id}_parallel_sample_{index}"
            group.seq_id_to_index[sub_request_id] = index
            seq_group = engine._add_processed_request(
                sub_request_id,
                params=single_params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[sub_request_id] = group
            group.to_be_finished[sub_request_id] = seq_group
            seqs.append(seq_group.seqs[0])

        # For parallel sampling the assembled group can be built right away:
        # every sub-request's sequence is known at this point.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            embeddings=seq_group.embeddings,
            pooling_params=seq_group.pooling_params,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        group.streaming = single_params.output_kind == RequestOutputKind.DELTA
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when an output should be emitted.

        Streaming mode returns it for the sub-request with index 0 and None
        for the others. Non-streaming mode returns it exactly once, after
        every sub-request has finished, trimmed to the top-``_real_n``
        sequences by cumulative logprob when ``_real_n`` is set.
        """
        if self.streaming:
            # NOTE(review): only the index-0 sample triggers output here. If
            # that sample finishes before its siblings, confirm the caller
            # still emits the remaining samples' deltas.
            is_first_sample = self.seq_id_to_index[seq_group.request_id] == 0
            return self.assembled_seq_group if is_first_sample else None

        # Non-streaming: wait until no sub-request is pending.
        if self.to_be_finished:
            return None

        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if self.output_produced:
            return None

        self.output_produced = True
        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            n = params._real_n or params.n
            ranked = sorted(self.assembled_seq_group.seqs,
                            key=lambda seq: seq.get_cumulative_logprob(),
                            reverse=True)
            self.assembled_seq_group.seqs = ranked[:n]
        return self.assembled_seq_group