sequence.py 41.6 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from abc import ABC, abstractmethod
5
from array import array
6
from collections import defaultdict
7
from dataclasses import dataclass, field
8
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
9
                    Union, cast)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
import torch

13
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
14
from vllm.lora.request import LoRARequest
15
from vllm.pooling_params import PoolingParams
16
from vllm.prompt_adapter.request import PromptAdapterRequest
17
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
if TYPE_CHECKING:
20
    from vllm.inputs import LLMInputs
21
    from vllm.multimodal import MultiModalDataDict
22
23
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

24
25
26

@dataclass
class Logprob:
    """Information supporting OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), if known.
        decoded_token: The decoded text of the chosen token, if available.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


39
40
# Prompt logprobs of a sequence group: a list of optional
# {token_id -> Logprob} dicts. NOTE(review): presumably one entry per prompt
# token position, with None where no logprob is available — confirm against
# the code that populates SequenceGroup.prompt_logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sampled-token logprobs of a single sequence: Sequence.append_token_id
# appends one {token_id -> Logprob} dict per generated token.
SampleLogprobs = List[Dict[int, Logprob]]
44

Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    The integer values are ordered: any status greater than ``SWAPPED``
    denotes a finished sequence, which ``is_finished`` relies on.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Every member after SWAPPED (2) counts as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* members."""
        return SequenceStatus.SWAPPED < status

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into its OpenAI-style finish reason.

        Returns None for any status that has no associated reason.
        """
        # FINISHED_IGNORED covers sequences whose prompt is longer than the
        # model's length cap, so (as in the OpenAI API) it reports "length".
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for candidate, reason in reason_table:
            if status == candidate:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
78

79

80
81
82
83
84
class SequenceStage(enum.Enum):
    """Processing phase of a sequence: prompt prefill, then token decoding."""
    PREFILL = 1
    DECODE = 2


85
86
87
88
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the most recent token was generated
            (initialized to the arrival time).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


104
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        # Materialize the prompt exactly once and derive the hashable tuple
        # view from the array. Iterating the argument twice would silently
        # leave the tuple empty if the caller passed a one-shot iterator.
        self._prompt_token_ids = array('l', prompt_token_ids)
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._output_token_ids = array(
            'l', output_token_ids if output_token_ids is not None else [])

        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached list of prompt token IDs followed by output
        token IDs returned by `get_token_ids`."""
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        self._prompt_token_ids = array('l', new_prompt_token_ids)
        # Derive from the array so the input is consumed only once.
        self._prompt_token_ids_tuple = tuple(self._prompt_token_ids)
        self._update_cached_all_tokens()

    @property
    def prompt_token_ids_array(self) -> array:
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids) -> None:
        self._output_token_ids = array('l', new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        return self._output_token_ids

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        # Keep the cached prompt+output list in sync without a full rebuild.
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns a pair of (prompt prefix, output prefix). The output part is
        None when the first `num_tokens` tokens lie entirely in the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens (prompt and, after a recompute,
        output) that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if no
        output has been generated yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self._output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
242
243


Woosuk Kwon's avatar
Woosuk Kwon committed
244
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False.) Must be True
                             for decoder-only model.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            pass `is_valid_encoder_decoder_llm_inputs`.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt
        # Lazily populated caches for the `prompt` / `prompt_token_ids`
        # properties below.
        self._prompt: Optional[str] = None
        self._prompt_token_ids: Optional[List[int]] = None

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing a encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all tokens
        (ceiling division of the sequence length)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        if self._prompt is not None:
            # Reuse precomputed prompt string
            return self._prompt

        # Select decoder or encoder input prompt str,
        # as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        # Cache prompt
        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
        return self._prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        if self._prompt_token_ids is not None:
            # Reuse precomputed prompt token ids
            return self._prompt_token_ids

        # Select decoder or encoder input prompt
        # token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        self._prompt_token_ids = cast(List[int],
                                      self.inputs.get(prompt_token_ids_key))
        return self._prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, holding back the last `buffer_length`
        characters while the sequence is still unfinished."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash all tokens up to and including logical block `logical_idx`,
        combined with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a sampled token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new sequence id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # The previous version ended with a dangling ", " and never closed
        # the parenthesis opened by "Sequence(".
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
480

Woosuk Kwon's avatar
Woosuk Kwon committed
481
482

class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        # Lets the accessors below short-circuit the single-sequence case;
        # kept up to date by `add` and `remove`.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last token and record `now` as
        the new last-token time (request-level timings).

        Raises:
            ValueError: If the group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally filtered by `status`."""
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far (unfinished sequences
        only)."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over the unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add `seq`; raise ValueError if its id is already present."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id`; raise ValueError if absent."""
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
729
730


731
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required
            when, e.g., prefill is chunked and the current iteration only
            computes query tokens for part of the prompt.
        pooling_params: Optional pooling parameters (presumably only set for
            embedding/pooling requests; None otherwise).
        token_chunk_size: The number of tokens to be processed (per sequence).
            If None, it is derived from ``seq_data``: the length of the first
            sequence for a prompt-stage request, otherwise 1.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are working
            with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reason.
        # TODO: We should maintain this state outside of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if seq_data is not None and self._token_chunk_size is None:
            if is_prompt:
                # No explicit chunk size: process the whole prompt of the
                # (first) sequence in one go.
                self._token_chunk_size = next(iter(
                    seq_data.values())).get_len()
            else:
                # Decode step: one new token per sequence.
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer ID of the LoRA adapter, or 0 when no LoRA is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """ID of the prompt adapter, or 0 when none is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Virtual-token count of the prompt adapter, or 0 when none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent sequence ID, output token and logprobs.

        For operands that are not ``SequenceOutput`` this returns
        ``NotImplemented``. Python then tries the reflected comparison and
        finally falls back to identity, rather than raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)


class SequenceGroupOutput(ABC):
    """Abstract base for model outputs produced for one sequence group.

    Concrete subclasses must implement ``__eq__`` and ``__repr__``.
    """

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        ...

    @abstractmethod
    def __repr__(self) -> str:
        ...


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The ``SequenceOutput`` candidates produced for the group.
        prompt_logprobs: Prompt logprob for each prompt query token, or None
            when not available.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        For operands of another type this returns ``NotImplemented``, so
        ``==`` falls back to the reflected comparison or identity instead of
        raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values for the group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown to keep the repr short.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings.

        For operands of another type this returns ``NotImplemented``, so
        ``==`` falls back to the reflected comparison or identity instead of
        raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return a tensor by name, or slice every tensor.

        A ``str`` key returns the named tensor. A ``slice`` key returns a new
        ``IntermediateTensors`` in which each tensor is indexed with that
        slice (i.e. along its first dimension).

        Raises:
            KeyError: If a ``str`` key is not present.
            TypeError: For any other key type.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Fail loudly instead of implicitly returning None.
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` holds the same tensor names and every
        pair of tensors has the same size and elements (``torch.equal``)."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


@dataclass
class SamplerOutput:
    """Sampler result for one model step.

    ``outputs`` holds one entry per sequence group; each entry lists the
    ``SequenceOutput`` objects, each a possible candidate for the next token.
    Instances support ``len()`` and index get/set like a list, and also carry
    optional device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __eq__(self, other: object):
        # Only the per-group outputs take part in equality.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Show tensor shapes rather than values to keep the repr compact."""
        probs_shape = (self.sampled_token_probs.shape
                       if self.sampled_token_probs is not None else "None")
        ids_shape = (self.sampled_token_ids.shape
                     if self.sampled_token_ids is not None else "None")
        metrics = self.spec_decode_worker_metrics
        return (f"SamplerOutput(outputs={self.outputs}, "
                f"sampled_token_probs={probs_shape}, "
                f"sampled_token_ids={ids_shape}, "
                f"spec_decode_worker_metrics={metrics})")


@dataclass
class PoolerOutput:
    """Result of the pooling operation of an embedding model.

    Holds one ``EmbeddingSequenceGroupOutput`` per sequence group and
    supports ``len()`` and index get/set like a list.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __len__(self):
        return len(self.outputs)

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __eq__(self, other: object):
        # Only the per-group outputs take part in equality.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence IDs of every group into a single list.

    IDs appear group by group, in the key order of each group's
    ``seq_data`` mapping.
    """
    seq_ids: List[int] = []
    for metadata in seq_group_metadata_list:
        # Iterating a dict yields its keys, i.e. the sequence IDs.
        seq_ids.extend(metadata.seq_data)
    return seq_ids


def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence IDs, both flat and grouped by request.

    Returns:
        A pair ``(seq_ids, request_id_seq_ids_mapping)``. ``seq_ids`` is the
        flat list of every sequence ID, group by group. The mapping (a
        ``defaultdict(set)``) maps each request ID to the set of its sequence
        IDs. Groups with no sequences add no mapping entry.
    """
    all_ids: List[int] = []
    ids_by_request: Dict[str, Set[int]] = defaultdict(set)
    for metadata in seq_group_metadata_list:
        group_ids = list(metadata.seq_data)
        if not group_ids:
            continue
        all_ids.extend(group_ids)
        ids_by_request[metadata.request_id].update(group_ids)
    return all_ids, ids_by_request


class HiddenStates:
    """Hidden states corresponding to in-progress sequences.

    Used in speculative decoding to pass hidden states from the target model
    to the proposer model in the subsequent step. ``seq_ids[i]`` is the
    sequence ID of row ``i`` along the batch (first) dimension of
    ``hidden_states``.
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append hidden states from a target model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        appended_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(appended_ids)
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only (and reorder to) the sequences of the given batch."""
        wanted_ids = get_all_seq_ids(seq_group_metadata_list)
        if wanted_ids == self.seq_ids:
            # Batch unchanged; nothing to do.
            return
        rows = [self.seq_ids.index(sid) for sid in wanted_ids]
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = wanted_ids


@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only.

    The LLM engine should create an instance of this class for each request
    batch.
    """
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        The three block-operation lists are shallow-copied. All other fields,
        including ``previous_hidden_states`` and ``finished_requests_ids``,
        are shared with ``self`` by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            num_steps=self.num_steps,
            previous_hidden_states=self.previous_hidden_states,
            finished_requests_ids=self.finished_requests_ids,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
        )