sequence.py 37.3 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from collections import defaultdict
7
from dataclasses import dataclass, field
8
9
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                    Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
import torch

13
from vllm.lora.request import LoRARequest
14
from vllm.pooling_params import PoolingParams
15
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
17

18
if TYPE_CHECKING:
19
    from vllm.inputs import LLMInputs
20
    from vllm.multimodal import MultiModalDataDict
21
22
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

23
24
25

@dataclass
class Logprob:
    """Per-token record backing OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The log probability of the chosen token.
        rank: The vocab rank of the chosen token (>=1). None when unset.
        decoded_token: The decoded chosen token, as a string. None when
            unset.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


38
39
# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
40
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
41
# {token_id -> logprob} for each sequence group.
42
SampleLogprobs = List[Dict[int, Logprob]]
43

Woosuk Kwon's avatar
Woosuk Kwon committed
44

45
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    The numeric order matters: `is_finished` treats every member whose
    value is greater than `SWAPPED` as a finished status.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # NOTE: every value above SWAPPED (2) is considered a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` compares greater than `SWAPPED`."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a status into its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the finished statuses, and None
            for any other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        # The ignored sequences are the sequences whose prompt lengths are
        # longer than the model's length cap, so they report "length" just
        # like length-capped sequences, as in the OpenAI API.
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
77

78

79
80
81
82
83
class SequenceStage(enum.Enum):
    """Stage of a sequence: PREFILL or DECODE."""
    # Values are assigned automatically in declaration order.
    PREFILL = enum.auto()
    DECODE = enum.auto()


84
85
86
87
@dataclass
class RequestMetrics:
    """Timing metrics recorded for a single request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time of the most recent token; the reference
            point for per-token latency.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    # Field order is part of the positional constructor signature.
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


103
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt. Any iterable of ints
            is accepted; it is consumed exactly once.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        # Materialize the prompt once and derive the tuple from that list.
        # Converting the argument twice would leave the tuple empty when the
        # caller passes a one-shot iterable.
        self._prompt_token_ids: List[int] = list(prompt_token_ids)
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._output_token_ids: List[int] = (
            list(output_token_ids) if output_token_ids is not None else [])

        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached prompt + output token list from scratch."""
        self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
                                                 self._output_token_ids)

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        """The prompt token IDs as an immutable tuple."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Same single-pass conversion as in __init__ so that the list view
        # and the tuple view always agree.
        self._prompt_token_ids = list(new_prompt_token_ids)
        self._prompt_token_ids_tuple = tuple(self._prompt_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        """The output token IDs as a freshly built tuple."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids) -> None:
        self._output_token_ids = list(new_output_token_ids)
        self._update_cached_all_tokens()

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        # Keep the cached full token list in sync without rebuilding it.
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return all token IDs, prompt first, then output.

        The returned list is the internal cache itself, not a copy.
        """
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns:
            A pair ``(prompt_part, output_part)``. ``output_part`` is None
            when ``num_tokens`` does not exceed the prompt length.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        """Return the prompt token IDs as a tuple."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        """Return the output token IDs as a tuple."""
        return self.output_token_ids

    @property
    def stage(self) -> "SequenceStage":
        """The current stage (PREFILL or DECODE) of this sequence."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self._output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
233
234


Woosuk Kwon's avatar
Woosuk Kwon committed
235
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any. Stored as-is.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
            self,
            seq_id: int,
            inputs: "LLMInputs",
            block_size: int,
            eos_token_id: Optional[int] = None,
            lora_request: Optional[LoRARequest] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        # `prompt_token_ids` is the property below, which reads `inputs`.
        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all tokens, rounded up."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        """The "prompt" entry of the inputs, or None if absent."""
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        """The "prompt_token_ids" entry of the inputs."""
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """The "multi_modal_data" entry of the inputs, or an empty dict."""
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is set."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """The prompt adapter ID, or 0 when no prompt adapter is set."""
        return self.prompt_adapter_request.prompt_adapter_id \
                        if self.prompt_adapter_request else 0

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, withholding the last `buffer_length`
        characters while the sequence is still unfinished."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix up to the end of block `logical_idx`,
        combined with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Return the number of tokens covered up to and including block
        `logical_idx`."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Record a newly generated token and its logprobs.

        `logprobs` must contain an entry for `token_id`.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # Close the parenthesis so the repr is a well-formed call-like
        # string (it previously ended with a dangling ", ").
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
403

Woosuk Kwon's avatar
Woosuk Kwon committed
404

Nick Hill's avatar
Nick Hill committed
405
406
407
408
409
@dataclass
class SequenceGroupState:
    """Mutable per-sequence-group state."""

    # The torch.Generator used in seeded sampling. Defaults to None.
    generator: Optional = None  # type: ignore
Nick Hill's avatar
Nick Hill committed
411
412


Woosuk Kwon's avatar
Woosuk Kwon committed
413
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        # Index the sequences by their IDs.
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # The last-token clock starts at the arrival time; the remaining
        # timings are filled in later.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        # Remember the first sequence; the prompt-related properties below
        # read from it.
        self._first_seq = next(iter(self.seqs_dict.values()))

    @property
    def prompt(self) -> Optional[str]:
        # Every sequence in the group shares one prompt, so the first
        # sequence's prompt is representative.
        return self._first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # Every sequence in the group shares one prompt, so the first
        # sequence's prompt token IDs are representative.
        return self._first_seq.prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # Every sequence in the group shares the same multi-modal data, so
        # the first sequence's data is representative.
        return self._first_seq.multi_modal_data

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is set."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def prompt_adapter_id(self) -> int:
        """The prompt adapter ID, or 0 when no prompt adapter is set."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """The prompt adapter's virtual token count, or 0 when no prompt
        adapter is set."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
        return 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last token and move the
        last-token time forward to `now`.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        elapsed = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return elapsed

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if self.metrics.first_token_time is not None:
            return
        if self.get_seqs()[0].get_output_len() == 1:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is not None:
            return
        self.metrics.first_scheduled_time = time
        self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        params = self.sampling_params
        if params:
            if params.use_beam_search:
                # For beam search, maximally there will always be `best_of`
                # beam candidates running in the future.
                return params.best_of
            if params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return params.best_of
        # At sampling stages, return the number of actual sequences
        # that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences, optionally only those with `status`."""
        all_seqs = self.seqs_dict.values()
        if status is None:
            return list(all_seqs)
        return [seq for seq in all_seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Return True when an encoder sequence is attached."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [s for s in self.seqs_dict.values() if not s.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        return [s for s in self.seqs_dict.values() if s.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if seq.is_finished():
                continue
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Return the total uncomputed-token count over unfinished
        sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.get_seqs() if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Return the number of sequences, optionally only those with
        `status`."""
        # Fast path: no filtering needed, so skip building the list.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id in self.seqs_dict:
            return self.seqs_dict[seq_id]
        raise ValueError(f"Sequence {seq_id} not found.")

    def add(self, seq: Sequence) -> None:
        """Add `seq` to the group.

        Raises:
            ValueError: If a sequence with the same ID is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id` from the group.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        """Return True when every sequence in the group is finished."""
        for seq in self.get_seqs():
            if not seq.is_finished():
                return False
        return True

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage, so checking the first
        # one is enough.
        first_seq = self.get_seqs()[0]
        return first_seq.is_prefill()

    def __repr__(self) -> str:
        seq_count = len(self.seqs_dict)
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={seq_count})")
622
623


624
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters, if any. Stored as-is.
        token_chunk_size: The number of tokens to be processed (per sequence).
            When None, it defaults to the length of the first sequence for
            a prompt and to 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A fresh
            `SequenceGroupState` is created when None.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        # Resolve the default chunk size first: the whole first sequence
        # for a prompt, a single token otherwise.
        if token_chunk_size is None:
            if is_prompt:
                first_seq_data = list(seq_data.values())[0]
                token_chunk_size = first_seq_data.get_len()
            else:
                token_chunk_size = 1

        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        if state is None:
            state = SequenceGroupState()
        self.state = state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens = None

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is set."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def prompt_adapter_id(self) -> int:
        """The prompt adapter ID, or 0 when no prompt adapter is set."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_id
        return 0

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """The prompt adapter's virtual token count, or 0 when no prompt
        adapter is set."""
        if self.prompt_adapter_request:
            return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
        return 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        chunk_size = self._token_chunk_size
        assert chunk_size is not None
        return chunk_size

722

Zhuohan Li's avatar
Zhuohan Li committed
723
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs with ``other``.

        Returns ``NotImplemented`` for operands that are not a
        ``SequenceOutput`` so that Python falls back to its default handling
        (``==`` evaluates to False) instead of raising.
        """
        if not isinstance(other, SequenceOutput):
            # Raising here would make `x == None` or `x in some_list` crash;
            # NotImplemented is the protocol-defined "cannot compare" signal.
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and other.logprobs == self.logprobs)
756
757


758
759
760
761
762
763
764
765
766
767
768
769
770
771
class SequenceGroupOutput(ABC):
    """Abstract base for the model outputs tied to one sequence group.

    Concrete subclasses have to provide both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a printable description of this output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Tell whether ``other`` holds the same output as this object."""


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The sampled outputs, one ``SequenceOutput`` per candidate.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs with ``other``.

        Returns ``NotImplemented`` for operands of another type so that
        Python falls back to its default handling (``==`` evaluates to
        False) instead of raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            # Raising here would make `x == None` or `x in some_list` crash;
            # NotImplemented is the protocol-defined "cannot compare" signal.
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

792

793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values produced for the sequence group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown; the raw values would be too noisy.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare the embeddings with those of ``other``.

        Returns ``NotImplemented`` for operands of another type so that
        Python falls back to its default handling (``==`` evaluates to
        False) instead of raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            # Raising here would make `x == None` or `x in some_list` crash;
            # NotImplemented is the protocol-defined "cannot compare" signal.
            return NotImplemented
        return self.embeddings == other.embeddings


812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    The object behaves like a thin wrapper around ``tensors``: string keys
    look up a single tensor, while a slice is applied to every tensor and
    yields a new ``IntermediateTensors``.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor (str key) or a sliced copy of all (slice key)."""
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Any other key type falls through and yields None, as before.
        return None

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True iff ``other`` holds the same names with equal tensors.

        Tensors are compared by shape, dtype, device and values. A plain
        dict ``==`` is not used because truth-testing the elementwise result
        of ``tensor == tensor`` is ambiguous for multi-element tensors.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        for name, mine in self.tensors.items():
            theirs = other.tensors[name]
            # Check the cheap metadata first so torch.equal is only called
            # on tensors it can actually compare.
            if (mine.shape != theirs.shape or mine.dtype != theirs.dtype
                    or mine.device != theirs.device):
                return False
            if not torch.equal(mine, theirs):
                return False
        return True

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


840
841
842
843
844
@dataclass
class SamplerOutput:
    """Sampling results for a batch of sequence groups.

    For each sequence group there is a list of SequenceOutput objects, each
    holding one possible candidate for the next token. The class forwards
    indexing and ``len`` to ``outputs`` so it can be used like a list, and
    additionally carries optional device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only ``outputs`` takes part in equality; tensors are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render the output, printing tensor shapes rather than values."""
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_repr = probs.shape if probs is not None else "None"
        token_ids_repr = token_ids.shape if token_ids is not None else "None"
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"sampled_token_ids={token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
891
892


893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
@dataclass
class PoolerOutput:
    """The output from a pooling operation in the embedding model.

    Indexing and ``len`` are forwarded to ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only ``outputs`` takes part in equality; metrics are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


914
915
916
917
918
919
920
921
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of all given groups into a single list.

    The ids are obtained by iterating each group's ``seq_data`` and are
    returned in the order in which the groups are listed.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        all_ids.extend(group.seq_data)
    return all_ids


922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect all sequence ids and group them by request id.

    Returns:
        A tuple of (flat list of every sequence id, in group order;
        defaultdict mapping each request id to the set of its sequence ids).
        Groups without any sequence contribute no entry to the mapping.
    """
    all_ids: List[int] = []
    ids_by_request: Dict[str, Set[int]] = defaultdict(set)
    for group in seq_group_metadata_list:
        group_ids = list(group.seq_data)
        all_ids += group_ids
        # Skip empty groups so no empty set is created for their request id.
        if group_ids:
            ids_by_request[group.request_id].update(group_ids)
    return all_ids, ids_by_request


937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
class HiddenStates:
    """Hidden states corresponding to in-progress sequences.

    Used in speculative decoding to pass hidden states from the target
    model to the proposer model in the subsequent step.

    ``seq_ids`` holds the sequence id for each entry along the batch
    dimension of the ``hidden_states`` tensor.
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        # One hidden-state row is expected per sequence group.
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append the ids and hidden states of a target model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_ids)
        self.hidden_states = torch.cat((self.hidden_states, hidden_states))

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the rows that belong to the given sequence groups."""
        kept_ids = get_all_seq_ids(seq_group_metadata_list)
        if kept_ids == self.seq_ids:
            # Batch contents unchanged - nothing to prune.
            return
        # Look up the current row of every surviving sequence id.
        rows = [self.seq_ids.index(seq_id) for seq_id in kept_ids]
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = kept_ids


969
970
@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Build a new request around a different sequence group list.

        The three block lists are shallow-copied. All remaining fields,
        including ``previous_hidden_states`` and ``finished_requests_ids``,
        are handed over as-is.
        """
        swap_in = self.blocks_to_swap_in.copy()
        swap_out = self.blocks_to_swap_out.copy()
        to_copy = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=swap_in,
            blocks_to_swap_out=swap_out,
            blocks_to_copy=to_copy,
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
        )