sequence.py 37.3 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from collections import defaultdict
7
from dataclasses import dataclass, field
8
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
11
import torch

12
from vllm.lora.request import LoRARequest
13
from vllm.pooling_params import PoolingParams
14
from vllm.prompt_adapter.request import PromptAdapterRequest
15
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
16

17
if TYPE_CHECKING:
18
    from vllm.inputs import LLMInputs
19
    from vllm.multimodal import MultiModalDataDict
20
21
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

22
23
24

@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1); None if not provided.
        decoded_token: The decoded text of the chosen token; None if not
            provided.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


37
38
# Prompt logprobs: a list whose entries are {token_id -> Logprob} mappings,
# or None for entries that carry no logprobs. SequenceGroup stores this as
# Optional[PromptLogprobs] and initializes it to None.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} mapping appended per generated
# token of a sequence.
SampleLogprobs = List[Dict[int, Logprob]]
42

Woosuk Kwon's avatar
Woosuk Kwon committed
43

44
class SequenceStatus(enum.IntEnum):
    """Status of a sequence.

    Members are ordered: any value greater than SWAPPED denotes a finished
    sequence, which is what `is_finished` relies on.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is ordered after SWAPPED (a FINISHED_*)."""
        return SequenceStatus.SWAPPED < status

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns "stop", "length" or "abort" for the FINISHED_* statuses and
        None for any other status.
        """
        # FINISHED_IGNORED covers sequences whose prompt length exceeded the
        # model's length cap, so it is reported as "length" like the OpenAI
        # API does.
        for finished_status, reason in (
            (SequenceStatus.FINISHED_STOPPED, "stop"),
            (SequenceStatus.FINISHED_LENGTH_CAPPED, "length"),
            (SequenceStatus.FINISHED_ABORTED, "abort"),
            (SequenceStatus.FINISHED_IGNORED, "length"),
        ):
            if status == finished_status:
                return reason
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
76

77

78
79
80
81
82
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: PREFILL until all of its tokens have
    been computed, DECODE afterwards."""
    PREFILL = 1
    DECODE = 2


83
84
85
86
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The reference time against which the next token
            latency is measured.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


102
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        # Materialize the prompt once and derive the tuple from that list so
        # a one-shot iterable is not consumed twice (which would leave the
        # tuple empty and out of sync with the list).
        self._prompt_token_ids: List[int] = list(prompt_token_ids)
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._output_token_ids: List[int] = (
            list(output_token_ids) if output_token_ids is not None else [])

        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self) -> None:
        """Rebuild the cached concatenation of prompt and output tokens."""
        self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
                                                 self._output_token_ids)

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # Same single materialization as in __init__, for the same reason.
        self._prompt_token_ids = list(new_prompt_token_ids)
        self._prompt_token_ids_tuple = tuple(self._prompt_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        # A fresh tuple, so callers cannot mutate the internal list.
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids) -> None:
        self._output_token_ids = list(new_output_token_ids)
        self._update_cached_all_tokens()

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one generated token and add its logprob to the total."""
        # Keep the cached prompt+output list in sync without rebuilding it.
        self._output_token_ids.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of prompt and output tokens."""
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return the prompt + output token IDs.

        NOTE: this is the internal cached list, not a copy; callers should
        not mutate it.
        """
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns a pair ``(prompt_part, output_part)``. When `num_tokens` does
        not exceed the prompt length, ``output_part`` is None. Otherwise the
        whole prompt is returned together with the leading output tokens.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed.

        This counts prompt tokens and, after a recompute, output tokens too.
        """
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens (prompt + output) not yet computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there is
        no output yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self._output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
232
233


Woosuk Kwon's avatar
Woosuk Kwon committed
234
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
            self,
            seq_id: int,
            inputs: "LLMInputs",
            block_size: int,
            eos_token_id: Optional[int] = None,
            lora_request: Optional[LoRARequest] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold all prompt and output tokens."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text to expose to the caller.

        While the sequence is unfinished, the last `buffer_length` characters
        are held back. A finished sequence, or a falsy `buffer_length`,
        returns the full text.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the tokens of blocks 0..`logical_idx` together with the LoRA
        ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Record a sampled token; `logprobs` must contain an entry for it."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new ID."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # The previous version ended with "num_blocks=N, " and never closed
        # the parenthesis.
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
402

Woosuk Kwon's avatar
Woosuk Kwon committed
403

Nick Hill's avatar
Nick Hill committed
404
405
406
407
408
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling; None until one is assigned.
    generator: Optional[torch.Generator] = None
Nick Hill's avatar
Nick Hill committed
410
411


Woosuk Kwon's avatar
Woosuk Kwon committed
412
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        arrival_time: The arrival time of the request.
        sampling_params: The sampling parameters used to generate the outputs.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last token and record `now`.

        Also sets `metrics.last_token_time` to `now`, so the next call
        measures the latency of the following token.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since, from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if (self.sampling_params
                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status equals `status`."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far (unfinished seqs only)."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
620
621


622
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters, for embedding models.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required; it then defaults to the full
            length of the first sequence for prompts and to 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group. A fresh
            SequenceGroupState is created when None.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        if self._token_chunk_size is None:
            if is_prompt:
                # No chunking: process the whole first sequence at once.
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size

720

Zhuohan Li's avatar
Zhuohan Li committed
721
class SequenceOutput:
    """A single sampled token for one sequence, as produced by the model.

    Args:
        parent_seq_id: ID of the sequence this token extends (relevant when
            beam search forks sequences).
        output_token: The sampled token ID.
        logprobs: Mapping of token ID to its logprob for this step
            (Token id -> logP(x_i+1 | x_0, ..., x_i)).
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        fields = (f"parent_seq_id={self.parent_seq_id}, "
                  f"output_token={self.output_token}, "
                  f"logprobs={self.logprobs}")
        return f"SequenceOutput({fields})"

    def __eq__(self, other: object) -> bool:
        # Comparing against a foreign type is treated as a programming error.
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        same_position = (self.parent_seq_id == other.parent_seq_id
                         and self.output_token == other.output_token)
        if not same_position:
            return False
        return other.logprobs == self.logprobs
754
755


756
757
758
759
760
761
762
763
764
765
766
767
768
769
class SequenceGroupOutput(ABC):
    """Abstract base for the model output of one sequence group.

    Concrete subclasses must implement both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation of this output."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether ``other`` holds an equal output."""


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """Model output for one completion (text generation) sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        # Sampled candidates produced for this group.
        self.samples = samples
        # Per-prompt-token logprobs; may be None.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        samples_part = f"samples={self.samples}"
        logprobs_part = f"prompt_logprobs={self.prompt_logprobs}"
        return (f"CompletionSequenceGroupOutput({samples_part}, "
                f"{logprobs_part})")

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CompletionSequenceGroupOutput):
            return (self.samples == other.samples
                    and self.prompt_logprobs == other.prompt_logprobs)
        # Comparing against a foreign type is treated as a programming error.
        raise NotImplementedError()

790

791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """Model output for one embedding sequence group."""

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown to keep logs short.
        size = len(self.embeddings)
        return f"EmbeddingSequenceGroupOutput(embeddings_shape={size})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, EmbeddingSequenceGroupOutput):
            return self.embeddings == other.embeddings
        # Comparing against a foreign type is treated as a programming error.
        raise NotImplementedError()


810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    # Named tensors (e.g. hidden states, residuals) keyed by name.
    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Look up a tensor by name, or slice every tensor.

        Args:
            key: A tensor name, or a slice applied to the first dimension of
                every stored tensor.

        Returns:
            The named tensor for a ``str`` key; a new IntermediateTensors of
            sliced tensors for a ``slice`` key.

        Raises:
            KeyError: if a ``str`` key names no stored tensor.
            TypeError: if ``key`` is neither ``str`` nor ``slice`` (this
                previously fell through and silently returned None).
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        raise TypeError("IntermediateTensors indices must be str or slice, "
                        f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True iff ``other`` holds the same names and equal tensors.

        Element-wise ``==`` on tensors yields a tensor, not a bool, so each
        pair is compared with ``torch.equal`` (same size and elements).
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        return all(
            torch.equal(value, other.tensors[name])
            for name, value in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


838
839
840
841
842
@dataclass
class SamplerOutput:
    """Per-sequence-group sampling results for one model step.

    Each entry of ``outputs`` holds the candidate next tokens for one
    sequence group. The class supports list-style indexing and ``len``, and
    carries optional device tensors alongside the outputs.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape rather than by value to keep output short.
        """

        def _shape_or_none(tensor):
            return "None" if tensor is None else tensor.shape

        probs_shape = _shape_or_none(self.sampled_token_probs)
        ids_shape = _shape_or_none(self.sampled_token_ids)
        return ("SamplerOutput("
                f"outputs={self.outputs}, "
                f"sampled_token_probs={probs_shape}, "
                f"sampled_token_ids={ids_shape}, "
                f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
889
890


891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
@dataclass
class PoolerOutput:
    """The output from a pooling operation in the embedding model.

    Supports list-style indexing and ``len`` over ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    # Optional spec-decode worker metrics (defaults to None).
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        if isinstance(other, self.__class__):
            return self.outputs == other.outputs
        return False


912
913
914
915
916
917
918
919
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every group into one list.

    Ids appear in group order, and within a group in ``seq_data`` key order.
    """
    all_seq_ids: List[int] = []
    for group in seq_group_metadata_list:
        all_seq_ids.extend(group.seq_data)
    return all_seq_ids


920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List["SequenceGroupMetadata"]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect the batch's sequence ids, both flat and grouped by request.

    Args:
        seq_group_metadata_list: The sequence groups in the batch.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:
        ``seq_ids`` lists every sequence id in group order (and within a
        group, in ``seq_data`` key order); ``request_id_seq_ids_mapping``
        maps each request id to the set of its sequence ids. The mapping is
        a ``defaultdict(set)``; groups sharing a request id are merged.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping


935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
class HiddenStates:
    """Hidden states corresponding to in-progress sequences.

    Used in speculative decoding to pass hidden states from the target model
    to the proposer model in the subsequent step.

    ``seq_ids[i]`` is the sequence id of row ``i`` along the batch (first)
    dimension of ``hidden_states``.
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        # NOTE(review): compares the number of *groups* with the number of
        # rows, which presumes one sequence per group — confirm.
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append hidden states from a target model invocation.

        The new rows are concatenated after the existing ones, and their
        sequence ids are appended to ``seq_ids`` in the same order.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the given batch's sequences, in the batch's order.

        Raises:
            ValueError: if the batch contains a sequence id with no stored
                hidden-state row.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids == self.seq_ids:
            return
        # Batch contents changed - prune removed sequences. Build the
        # id -> row map once (O(n)) instead of calling list.index per id
        # (O(n^2)). setdefault keeps the first row, matching list.index.
        row_of: Dict[int, int] = {}
        for row, seq_id in enumerate(self.seq_ids):
            row_of.setdefault(seq_id, row)
        index: List[int] = []
        for seq_id in seq_ids:
            row = row_of.get(seq_id)
            if row is None:
                # Same exception type and message list.index() raised.
                raise ValueError(f"{seq_id} is not in list")
            index.append(row)
        self.hidden_states = self.hidden_states[index]
        self.seq_ids = seq_ids


967
968
@dataclass
class ExecuteModelRequest:
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List["SequenceGroupMetadata"]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional["HiddenStates"] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)

    def clone(
        self, seq_group_metadata_list: List["SequenceGroupMetadata"]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        All list-valued fields are shallow-copied so that mutating the
        clone's lists does not affect this request. ``finished_requests_ids``
        was previously shared by reference, unlike the other lists.
        ``previous_hidden_states`` is still shared by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids.copy())