sequence.py 40.9 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from array import array
7
from collections import defaultdict
8
from dataclasses import dataclass, field
9
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
10
                    Union, cast)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
import torch

14
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
15
from vllm.lora.request import LoRARequest
16
from vllm.pooling_params import PoolingParams
17
from vllm.prompt_adapter.request import PromptAdapterRequest
18
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
19

20
if TYPE_CHECKING:
21
    from vllm.inputs import LLMInputs
22
    from vllm.multimodal import MultiModalDataDict
23
24
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

25
26
27

@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of chosen token
        rank: The vocab rank of chosen token (>=1)
        decoded_token: The decoded text of the chosen token, or None if it
            has not been detokenized.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Prompt logprobs for one sequence group: one entry per prompt token, each a
# {token_id -> logprob} dict. None if the corresponding sequence group doesn't
# require prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs for one sequence: one {token_id -> logprob} dict per
# generated token.
SampleLogprobs = List[Dict[int, Logprob]]
45

Woosuk Kwon's avatar
Woosuk Kwon committed
46

47
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Members are ordered by value; every member whose value exceeds
    ``SWAPPED`` (2) denotes a finished sequence.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* members."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into its finish-reason string.

        Returns:
            "stop", "length" or "abort" for finished statuses, and None for
            any status that has no associated finish reason.
        """
        # FINISHED_IGNORED covers sequences whose prompt is longer than the
        # model's length cap, so it reports "length" as in the OpenAI API.
        reason_by_status = {
            SequenceStatus.FINISHED_STOPPED: "stop",
            SequenceStatus.FINISHED_LENGTH_CAPPED: "length",
            SequenceStatus.FINISHED_ABORTED: "abort",
            SequenceStatus.FINISHED_IGNORED: "length",
        }
        return reason_by_status.get(status)
Woosuk Kwon's avatar
Woosuk Kwon committed
79

80

81
82
83
84
85
class SequenceStage(enum.Enum):
    """Whether a sequence is still prefilling its tokens or is decoding."""
    PREFILL = 1
    DECODE = 2


86
87
88
89
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the most recent token was generated
            (or the arrival time, before any token has been generated).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished; None until
            it is set.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


105
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt. Any iterable of ints
            is accepted; it is materialized exactly once.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        # Materialize the prompt once so the array and the hashable tuple
        # view always agree, even when a one-shot iterator is passed in.
        prompt_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
        self._prompt_token_ids = array('l', prompt_tuple)
        self._prompt_token_ids_tuple: Tuple[int, ...] = prompt_tuple
        self._output_token_ids = array(
            'l', output_token_ids if output_token_ids is not None else [])

        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self) -> None:
        """Rebuild the cached prompt + output token list."""
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Iterate the input only once (see __init__).
        prompt_tuple = tuple(new_prompt_token_ids)
        self._prompt_token_ids = array('l', prompt_tuple)
        self._prompt_token_ids_tuple = prompt_tuple
        self._update_cached_all_tokens()

    @property
    def prompt_token_ids_array(self) -> array:
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids) -> None:
        self._output_token_ids = array('l', new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        return self._output_token_ids

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        # Keep the cached prompt + output list in sync without a full rebuild.
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return the cached list of prompt + output token IDs."""
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns:
            A ``(prompt_prefix, output_prefix)`` pair covering the first
            ``num_tokens`` tokens. ``output_prefix`` is None when the prefix
            lies entirely within the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far.

        Switches the stage to DECODE once every token has been computed.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt and output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if none."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self._output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
243
244


Woosuk Kwon's avatar
Woosuk Kwon committed
245
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False.) Must be True
                             for decoder-only model.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            pass `is_valid_encoder_decoder_llm_inputs`.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt
        self._prompt: Optional[str] = None
        self._prompt_token_ids: Optional[List[int]] = None

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing a encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all current tokens."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        """The decoder or encoder prompt string, looked up lazily."""
        if self._prompt is not None:
            # Reuse precomputed prompt string
            return self._prompt

        # Select decoder or encoder input prompt str,
        # as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        # Cache prompt
        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
        return self._prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        """The decoder or encoder prompt token IDs, looked up lazily."""
        if self._prompt_token_ids is not None:
            # Reuse precomputed prompt token ids
            return self._prompt_token_ids

        # Select decoder or encoder input prompt
        # token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        self._prompt_token_ids = cast(List[int],
                                      self.inputs.get(prompt_token_ids_key))
        return self._prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, minus the trailing `buffer_length` chars
        while the sequence is still running."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix that ends with block `logical_idx`.

        The LoRA id is mixed into the hash so identical token prefixes under
        different LoRA adapters produce different hashes.
        """
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a sampled token; `logprobs` must contain an entry for it."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new sequence ID."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
481

Woosuk Kwon's avatar
Woosuk Kwon committed
482
483

class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Index by seq_id for O(1) find/add/remove.
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last token and record `now`.

        Used for request-level timings. As a side effect,
        `metrics.last_token_time` is set to `now`.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those with the given status."""
        if status is None:
            return self.seqs
        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens across all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`.

        Raises:
            ValueError: If no sequence with that ID is in the group.
        """
        try:
            return self.seqs_dict[seq_id]
        except KeyError:
            raise ValueError(f"Sequence {seq_id} not found.") from None

    def add(self, seq: Sequence) -> None:
        """Add `seq` to the group; raises ValueError on a duplicate ID."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id`; raises ValueError if absent."""
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
707
708


709
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters, if any.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required; it then defaults to the length
            of the first sequence in `seq_data` for a prompt, and to 1
            otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this state outside of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            # No explicit chunk size: process the whole (first) sequence for
            # a prompt, or a single token otherwise.
            if is_prompt:
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer id, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """The prompt adapter id, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """The prompt adapter's virtual-token count, or 0 without one."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size

804

Zhuohan Li's avatar
Zhuohan Li committed
805
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when parent id, output token and logprobs all match.

        For objects of other types this returns ``NotImplemented`` (per the
        Python data model) rather than raising. Python then falls back to the
        reflected comparison or identity, so ``output == None`` or
        ``output in mixed_list`` evaluates instead of crashing.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
838
839


840
841
842
843
844
845
846
847
848
849
850
851
852
853
class SequenceGroupOutput(ABC):
    """Abstract interface for the model output of one sequence group.

    Concrete subclasses must provide both `__repr__` and `__eq__`; this base
    class itself cannot be instantiated.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debug representation of the output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with `other`."""


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The `SequenceOutput` objects produced for this group.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when both samples and prompt logprobs match.

        Returns ``NotImplemented`` for objects of other types, so Python can
        fall back to the reflected comparison or identity instead of raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

874

875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values produced for the group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown; embeddings can be long.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Equal when the embedding values match.

        Returns ``NotImplemented`` for objects of other types, so Python can
        fall back to the reflected comparison or identity instead of raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one named tensor, or a sliced copy of the whole set.

        A ``str`` key returns the tensor stored under that name. A ``slice``
        key returns a new ``IntermediateTensors`` in which every tensor is
        indexed with that slice (i.e. sliced along its first dimension).

        Raises:
            KeyError: If a ``str`` key is not present.
            TypeError: For any other key type (previously these silently
                returned None).
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same keys and every pair of tensors has
        the same shape and elements.

        The previous implementation returned ``self`` (truthy whenever any
        tensor was present), so unrelated instances compared equal and two
        empty instances compared unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(self.tensors[k], other.tensors[k])
            for k in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


922
923
924
925
926
@dataclass
class SamplerOutput:
    """Per-group sampling results for one model step.

    Every sequence group gets a list of `SequenceOutput` objects, each holding
    one possible candidate for the next token. The object supports indexing,
    item assignment and `len` over those group outputs like a list, and also
    carries optional device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        selected = self.outputs[idx]
        return selected

    def __setitem__(self, idx: int, value):
        target = self.outputs
        target[idx] = value

    def __eq__(self, other: object):
        # Only the group outputs take part in equality; tensors are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Summarize the output, printing tensor shapes rather than tensor
        values to keep the text short.
        """

        def shape_or_none(tensor: Optional[torch.Tensor]):
            return "None" if tensor is None else tensor.shape

        probs_repr = shape_or_none(self.sampled_token_probs)
        ids_repr = shape_or_none(self.sampled_token_ids)
        metrics = self.spec_decode_worker_metrics
        return ("SamplerOutput("
                f"outputs={self.outputs}, "
                f"sampled_token_probs={probs_repr}, "
                f"sampled_token_ids={ids_repr}, "
                f"spec_decode_worker_metrics={metrics})")
973
974


975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
@dataclass
class PoolerOutput:
    """The output from a pooling operation in the embedding model.

    Supports indexing, item assignment and `len` over the per-group embedding
    outputs, like a list.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        group_output = self.outputs[idx]
        return group_output

    def __setitem__(self, idx: int, value):
        target = self.outputs
        target[idx] = value

    def __eq__(self, other: object):
        # Only the group outputs take part in equality; metrics are ignored.
        if isinstance(other, self.__class__):
            return self.outputs == other.outputs
        return False


996
997
998
999
1000
1001
1002
1003
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into one list.

    Ids keep batch order: groups in list order, and within a group the
    iteration order of its `seq_data` dict.
    """
    all_ids: List[int] = []
    for group_metadata in seq_group_metadata_list:
        all_ids.extend(group_metadata.seq_data)
    return all_ids


1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect all sequence ids of a batch and group them by request id.

    Args:
        seq_group_metadata_list: The sequence groups of the batch.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:

        - ``seq_ids`` lists every sequence id. Groups appear in list order,
          and within a group ids follow the iteration order of ``seq_data``.
        - ``request_id_seq_ids_mapping`` maps each request id to the set of
          its sequence ids. It is a ``defaultdict(set)``, so looking up an
          unknown request id yields (and inserts) an empty set.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
class HiddenStates:
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model in the subsequent step.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        """Record hidden states for the sequences of a batch.

        Args:
            seq_group_metadata_list: The batch the hidden states belong to.
            hidden_states: One entry along the batch dimension per batch
                entry, in batch order.
        """
        # NOTE(review): this compares the number of *groups* with the batch
        # dimension, while seq_ids holds one entry per *sequence*. The two
        # only agree when every group has exactly one sequence, which
        # presumably holds for speculative decoding. Confirm.
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Update hidden states from target model invocation.

        Appends the new batch's sequence ids and concatenates its hidden
        states after the existing ones (along dim 0).
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids.

        Keeps only the rows of sequences present in the given batch, reordered
        to match that batch. If an id occurs more than once among the stored
        ids, its first row is used.

        Raises:
            ValueError: If a sequence in the batch has no stored row.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self.seq_ids:
            # Batch contents changed - prune removed sequences.
            # Build the id -> first-row map once (O(n)) instead of calling
            # list.index per id (O(n^2) overall). setdefault keeps the first
            # occurrence, matching list.index semantics.
            first_row: Dict[int, int] = {}
            for row, stored_id in enumerate(self.seq_ids):
                first_row.setdefault(stored_id, row)
            index: List[int] = []
            for seq_id in seq_ids:
                if seq_id not in first_row:
                    # Same exception type/message list.index would raise.
                    raise ValueError(f"{seq_id!r} is not in list")
                index.append(first_row[seq_id])
            self.hidden_states = self.hidden_states[index]
            self.seq_ids = seq_ids


1051
1052
@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        All list fields (block swap/copy lists and ``finished_requests_ids``)
        are shallow-copied, so mutating the clone's lists does not affect this
        request. Previously ``finished_requests_ids`` alone was shared.
        ``previous_hidden_states`` is still shared by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids.copy())