sequence.py 37.5 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from array import array
7
from collections import defaultdict
8
from dataclasses import dataclass, field
9
10
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                    Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
import torch

14
from vllm.lora.request import LoRARequest
15
from vllm.pooling_params import PoolingParams
16
from vllm.prompt_adapter.request import PromptAdapterRequest
17
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
if TYPE_CHECKING:
20
    from vllm.inputs import LLMInputs
21
    from vllm.multimodal import MultiModalDataDict
22
23
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

24
25
26

@dataclass
class Logprob:
    """Log-probability information for one token, in an OpenAI-compatible form.

    Attributes:
        logprob: Log probability of the token.
        rank: Rank of the token within the vocabulary (>= 1), if known.
        decoded_token: Text of the decoded token, if it has been detokenized.
    """
    logprob: float
    rank: Optional[int] = field(default=None)
    decoded_token: Optional[str] = field(default=None)


# Prompt logprobs: a list of {token_id -> Logprob} mappings. An entry is None
# when the corresponding sequence group does not request prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sampled logprobs: a list of {token_id -> Logprob} mappings (never None).
SampleLogprobs = List[Dict[int, Logprob]]
44

Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
class SequenceStatus(enum.IntEnum):
    """Lifecycle status of a sequence.

    Values are ordered: every status numerically above ``SWAPPED`` is a
    finished status, which ``is_finished`` relies on.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Everything below has a value greater than SWAPPED (2) and is therefore
    # treated as finished.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if ``status`` is any FINISHED_* status."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns None for statuses that have no associated reason.
        """
        reason_table = (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            # Ignored sequences are those whose prompt exceeds the model's
            # length cap, so the reason is "length", as in the OpenAI API.
            (SequenceStatus.FINISHED_IGNORED, "length"),
        )
        for candidate, reason in reason_table:
            if status == candidate:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
78

79

80
81
82
83
84
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prompt prefill or token decoding."""
    # Prompt (and, after recompute, output) tokens still need computing.
    PREFILL = enum.auto()
    # All tokens are computed; new tokens are generated one at a time.
    DECODE = enum.auto()


85
86
87
88
@dataclass
class RequestMetrics:
    """Request-level timing metrics.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time the most recent token was recorded
            (initialized to the arrival time).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    # Required timestamps.
    arrival_time: float
    last_token_time: float
    # Filled in as the request progresses; None until then.
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


104
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        self._prompt_token_ids = array('l', prompt_token_ids)
        # Build the tuple from the array rather than from the argument: a
        # one-shot iterable (e.g. a generator) is exhausted by the array
        # constructor, which would leave the tuple empty and out of sync.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._output_token_ids = array(
            'l', output_token_ids if output_token_ids is not None else [])

        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached list of prompt + output token IDs."""
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        """The prompt token IDs as an immutable tuple."""
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        self._prompt_token_ids = array('l', new_prompt_token_ids)
        # Derive the tuple from the array so one-shot iterables are not
        # consumed twice.
        self._prompt_token_ids_tuple = tuple(self._prompt_token_ids)
        self._update_cached_all_tokens()

    @property
    def prompt_token_ids_array(self) -> array:
        """The underlying prompt token array (not a copy)."""
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        """A tuple snapshot of the output token IDs (built on each access)."""
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids) -> None:
        self._output_token_ids = array('l', new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """The underlying output token array (not a copy)."""
        return self._output_token_ids

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append a generated token and add its logprob to the running sum."""
        self._output_token_ids.append(token_id)
        # Keep the cached full-token list in sync without rebuilding it.
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return prompt + output token IDs.

        This is the internal cached list, not a copy; callers should not
        mutate it.
        """
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns a pair of (prompt prefix, output prefix). The output part is
        None when the first ``num_tokens`` tokens lie entirely in the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far.

        Switches the stage to DECODE once every token has been computed.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if none."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self._output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
242
243


Woosuk Kwon's avatar
Woosuk Kwon committed
244
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The EOS token ID, stored as-is on the sequence.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
            self,
            seq_id: int,
            inputs: "LLMInputs",
            block_size: int,
            eos_token_id: Optional[int] = None,
            lora_request: Optional[LoRARequest] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all current tokens."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # Missing or falsy multi-modal data is normalized to an empty dict.
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, holding back the last ``buffer_length``
        characters unless the sequence is finished."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at logical block ``logical_idx``,
        combined with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..``logical_idx`` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token; ``logprobs`` must contain ``token_id``."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence with a new ID."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # Previously ended with a dangling ", " and no closing parenthesis.
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
412

Woosuk Kwon's avatar
Woosuk Kwon committed
413

Nick Hill's avatar
Nick Hill committed
414
415
416
417
418
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling (None by default). Typed
    # precisely instead of a bare `Optional` plus `# type: ignore`.
    generator: Optional[torch.Generator] = None
Nick Hill's avatar
Nick Hill committed
420
421


Woosuk Kwon's avatar
Woosuk Kwon committed
422
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. Must contain at least one sequence.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.

    Raises:
        ValueError: If ``seqs`` is empty.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        # Without this check, `next(iter(...))` below would leak a bare
        # StopIteration out of the constructor.
        if not self.seqs_dict:
            raise ValueError(
                f"SequenceGroup {request_id} must contain at least one "
                "sequence.")
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers
        self._first_seq = next(iter(self.seqs_dict.values()))

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self._first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self._first_seq.prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self._first_seq.multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def get_last_latency(self, now: float) -> float:
        """Return the time elapsed since the last recorded token.

        As a side effect, ``metrics.last_token_time`` is advanced to ``now``.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those with the given status."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
631
632


633
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Pooling parameters for the request, if any.
        token_chunk_size: The number of tokens to be processed (per sequence).
            If None, it defaults to the full length of the first sequence for
            prompts, or 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A fresh state is
            created if None.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            if is_prompt:
                # Process the whole (first) sequence in one chunk.
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size

731

Zhuohan Li's avatar
Zhuohan Li committed
732
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Field-wise equality with another ``SequenceOutput``.

        Comparing against any other type returns ``NotImplemented`` rather
        than raising, so Python falls back to its default (identity-based)
        comparison and e.g. ``seq_output == None`` evaluates to False.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
765
766


767
768
769
770
771
772
773
774
775
776
777
778
779
780
class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs tied to one sequence group.

    Concrete subclasses must implement ``__repr__`` and ``__eq__``; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def __repr__(self) -> str:
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        ...


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The sampled ``SequenceOutput`` candidates for the group.
        prompt_logprobs: Prompt logprob for each prompt query token (may be
            None).
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare ``samples`` and ``prompt_logprobs``.

        Foreign types yield ``NotImplemented`` (instead of raising) so Python
        falls back to its default comparison, making ``==`` against None or
        other output types evaluate to False.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

801

802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values for the group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown to keep logs short for large embeddings.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare embeddings; foreign types yield ``NotImplemented``.

        Returning ``NotImplemented`` (instead of raising) lets Python fall
        back to its default comparison, so ``==`` with unrelated objects
        evaluates to False.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings


821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor.

        Args:
            key: A tensor name (``str``) or a ``slice`` applied to the first
                dimension of every stored tensor.

        Returns:
            The named tensor for a ``str`` key; a new ``IntermediateTensors``
            holding the sliced tensors for a ``slice`` key.

        Raises:
            KeyError: If a ``str`` key is not present.
            TypeError: If ``key`` is neither ``str`` nor ``slice``.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and returned None, hiding caller bugs.
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """True when ``other`` has the same tensor names and every tensor is
        element-wise equal (same shape and values, via ``torch.equal``)."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # Tensor ``==`` is element-wise and has no single truth value, so
        # compare each pair with torch.equal instead.
        return all(
            torch.equal(self.tensors[name], other.tensors[name])
            for name in self.tensors)

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


849
850
851
852
853
@dataclass
class SamplerOutput:
    """Sampler result for a batch of sequence groups.

    ``outputs`` holds one entry per sequence group; each entry carries
    SequenceOutput objects, one per candidate for the next token. The class
    supports indexing, item assignment and ``len()`` so it can be used like
    a list, and optionally carries device tensors produced while sampling.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # Device tensor with the probability of each token (optional).
    sampled_token_probs: Optional[torch.Tensor] = None

    # Device tensor with the logprob of each token (optional).
    logprobs: Optional["torch.Tensor"] = None

    # Device tensor with the sampled token ids (optional).
    sampled_token_ids: Optional[torch.Tensor] = None

    # Speculative-decoding metrics filled in by workers (optional).
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Last hidden states of the model (optional).
    hidden_states: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        """Return the output of the sequence group at ``idx``."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output of the sequence group at ``idx``."""
        self.outputs[idx] = value

    def __len__(self):
        """Number of per-sequence-group outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when ``other`` is the same type with equal ``outputs``; the
        optional tensor and metrics fields are not compared."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape rather than by value to keep logs short.
        """
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_repr = "None" if probs is None else probs.shape
        ids_repr = "None" if token_ids is None else token_ids.shape
        return (
            "SamplerOutput("
            f"outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"sampled_token_ids={ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
900
901


902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
@dataclass
class PoolerOutput:
    """Result of the pooling step of an embedding model.

    Holds one EmbeddingSequenceGroupOutput per sequence group and behaves
    like a list of them (indexing, item assignment, ``len()``).
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Optional speculative-decoding worker metrics (None when unset).
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        """Return the output of the sequence group at ``idx``."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output of the sequence group at ``idx``."""
        self.outputs[idx] = value

    def __len__(self):
        """Number of per-sequence-group outputs."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Equal when ``other`` is the same type with equal ``outputs``."""
        if isinstance(other, self.__class__):
            return self.outputs == other.outputs
        return False


923
924
925
926
927
928
929
930
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence IDs of every group into one list.

    IDs appear in group order, and within a group in the iteration order of
    that group's ``seq_data`` mapping.
    """
    seq_ids: List[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        # Iterating a mapping yields its keys, i.e. the sequence IDs.
        seq_ids.extend(seq_group_metadata.seq_data)
    return seq_ids


931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect all sequence IDs and group them by request ID.

    Args:
        seq_group_metadata_list: The sequence groups to scan.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:

        * ``seq_ids`` -- every sequence ID, in group order and then in each
          group's ``seq_data`` iteration order.
        * ``request_id_seq_ids_mapping`` -- maps each group's ``request_id``
          to the set of its sequence IDs. It is a ``defaultdict(set)``, so
          indexing an unknown request ID returns (and inserts) an empty set.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
class HiddenStates:
    """Hidden states of in-progress sequences, used by speculative decoding.

    The target model's hidden states are kept here so the proposer model can
    consume them on the following step. ``seq_ids[i]`` is the sequence ID of
    row ``i`` of ``hidden_states`` (its first, batch dimension).
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        # NOTE(review): this checks the number of *groups* against the number
        # of rows, while rows are keyed by *sequence* ID; the two only agree
        # when each group holds exactly one sequence.
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append the rows produced by a new target-model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_seq_ids)
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the rows of sequences in ``seq_group_metadata_list``,
        reordered to match it.

        Raises ValueError (from ``list.index``) if a requested sequence ID is
        not currently tracked.
        """
        wanted = get_all_seq_ids(seq_group_metadata_list)
        if wanted == self.seq_ids:
            return
        # The batch changed: gather the surviving rows in the new order.
        rows = [self.seq_ids.index(seq_id) for seq_id in wanted]
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = wanted


978
979
@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        Every list field is shallow-copied, so adding or removing items in
        the clone's lists does not affect this request.
        ``previous_hidden_states`` is shared, not copied.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the block lists above so the clone owns its list.
            finished_requests_ids=self.finished_requests_ids.copy())