sequence.py 33 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import math
5
from abc import ABC, abstractmethod
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.pooling_params import PoolingParams
13
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
if TYPE_CHECKING:
16
    from vllm.inputs import LLMInputs
17
    from vllm.multimodal import MultiModalData
18
19
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

20
21
22

@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1). None when no rank
            was supplied.
        decoded_token: The decoded string form of the chosen token. None
            when no decoded text was supplied.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


35
36
# List of {token_id -> Logprob} mappings for the prompt of a sequence group.
# An entry is None when no prompt logprob is available/required for it.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# List of {token_id -> Logprob} mappings for sampled tokens; `Sequence`
# appends one mapping per generated token.
SampleLogprobs = List[Dict[int, Logprob]]
40

Woosuk Kwon's avatar
Woosuk Kwon committed
41

42
class SequenceStatus(enum.IntEnum):
    """Status of a sequence."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is one of the FINISHED_* statuses.

        Relies on the integer ordering of the members: every value greater
        than SWAPPED is a finished status.
        """
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for finished statuses, and None for
            any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            return "length"
        return None
Woosuk Kwon's avatar
Woosuk Kwon committed
74

75

76
77
78
79
80
class SequenceStage(enum.Enum):
    """Computation stage of a sequence.

    `SequenceData` starts in PREFILL and moves to DECODE once every token
    of the sequence has been computed.
    """
    # Tokens of the sequence still have to be run through the model.
    PREFILL = enum.auto()
    # All tokens are computed.
    DECODE = enum.auto()


81
82
83
84
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the most recent token latency was
            measured. `SequenceGroup` initialises it to the arrival time
            and updates it in `get_last_latency`.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


100
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        # Immutable snapshot of the prompt taken at construction time; used
        # to build hashable prefixes in `get_prefix_token_ids`. It is not
        # refreshed if `prompt_token_ids` is mutated afterwards.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            prompt_token_ids)
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append an output token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt tokens then the outputs."""
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns:
            A `(prompt_part, output_part)` pair. When `num_tokens` exceeds
            the prompt length, `prompt_part` is the whole prompt and
            `output_part` holds the leading output tokens needed to reach
            `num_tokens`. Otherwise `prompt_part` is the first `num_tokens`
            prompt tokens and `output_part` is None.
        """
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far.

        Switches the stage to DECODE once no uncomputed tokens remain.
        """
        self._num_computed_tokens += num_new_computed_tokens
        # Internal invariant: we can never have computed more tokens than
        # the sequence contains.
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens that are not computed yet."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token when no
        output has been generated yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        """Return the prompt token list itself (not a copy)."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        """Return the output token list itself (not a copy)."""
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """The current computation stage (PREFILL or DECODE)."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
205
206


Woosuk Kwon's avatar
Woosuk Kwon committed
207
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, stored on the sequence
            as-is. May be None.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks of `block_size` needed to hold all tokens."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        """The "prompt" entry of the inputs, or None if absent."""
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        """The "prompt_token_ids" entry of the inputs (required)."""
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        """The "multi_modal_data" entry of the inputs, or None if absent."""
        return self.inputs.get("multi_modal_data")

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int) -> str:
        """Return the output text, holding back the last `buffer_length`
        characters while the sequence is still running."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx` together
        with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self) -> None:
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token and record its logprobs.

        `logprobs` must contain an entry for `token_id`.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # The previous form ended in a dangling ", " without the closing
        # parenthesis; emit a balanced representation instead.
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
Woosuk Kwon's avatar
Woosuk Kwon committed
366

Woosuk Kwon's avatar
Woosuk Kwon committed
367

Nick Hill's avatar
Nick Hill committed
368
369
370
371
372
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
    generator: Optional[torch.Generator] = None
Nick Hill's avatar
Nick Hill committed
374
375


Woosuk Kwon's avatar
Woosuk Kwon committed
376
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
                     unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # last_token_time starts at the arrival time so that the first
        # latency measurement is taken relative to the request arrival.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt_token_ids

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).multi_modal_data

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the last token time and advance
        the last token time to `now`.

        Raises:
            ValueError: If the sequence group is still in the prefill phase.
        """
        # If still in prefill phase, raise Error.
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        # Otherwise return token latency.
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: in a case where a sequence_group is swapped and
        #   recomputed, the time between iterations is counted
        #   in TPOT, rather than recalculating TTFT (since from the
        #   POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.get_seqs()[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        params = self.sampling_params
        if params and params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return params.best_of
        if params and params.best_of > self.num_seqs():
            # At prompt stage, the sequence group is not yet filled up
            # and only have one sequence running. However, in the
            # generation stage, we will have `best_of` sequences running.
            return params.best_of
        # At sampling stages, return the number of actual sequences
        # that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences of the group as a new list, optionally
        restricted to those whose status equals `status`."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total number of uncomputed tokens over all unfinished sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.seqs_dict.values()
                   if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs_dict)

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add `seq` to the group.

        Raises:
            ValueError: If a sequence with the same ID is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id` from the group.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        """True when every sequence of the group is finished."""
        return all(seq.is_finished() for seq in self.seqs_dict.values())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
571
572


573
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters, stored as-is. May be None.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
                          model.
        cross_block_table: Optional cross-attention block table associated
                           with the encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalData"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        # Fall back to a fresh state object when the caller supplies none.
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # The number of speculative tokens adopted in this request.
        # None means speculative decoding is not used.
        # Zero means speculative decoding is disabled for some reasons.
        # TODO: We should maintain this states out of the sequence group.
        self.num_speculative_tokens: Optional[int] = None

        # Default chunk size when the caller gave none: the full length of
        # the first sequence at prompt stage, a single token otherwise.
        if self._token_chunk_size is None:
            if is_prompt:
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        assert self._token_chunk_size is not None
        return self._token_chunk_size

658

Zhuohan Li's avatar
Zhuohan Li committed
659
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, "Logprob"],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        Returns `NotImplemented` for operands of another type so that
        `==` / `!=` against unrelated objects evaluate normally instead of
        raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal
692
693


694
695
696
697
698
699
700
701
702
703
704
705
706
707
class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a string representation; subclasses must override."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare with `other`; subclasses must override."""


class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The sampled outputs of the group.
        prompt_logprobs: Prompt logprob for each prompt query token, or
            None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        Returns `NotImplemented` for operands of another type so that
        `==` / `!=` against unrelated objects evaluate normally instead of
        raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

728

729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Attributes:
        embeddings: The embedding values stored for this sequence group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown so that large vectors do not flood logs.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare the stored embeddings with another embedding output.

        Returns ``NotImplemented`` for operands of a different type so that
        Python can try the reflected comparison and finally fall back to
        ``False`` instead of raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            # Raising NotImplementedError here would make `x == None` or
            # `x in mixed_list` crash; NotImplemented is the protocol value.
            return NotImplemented
        return self.embeddings == other.embeddings


748
749
750
751
752
@dataclass
class SamplerOutput:
    """Sampler results for a batch of sequence groups.

    ``outputs`` holds one entry per sequence group; each entry carries the
    candidate next-token outputs for that group. Indexing, item assignment
    and ``len()`` are forwarded to ``outputs`` so the object can be used
    like a list, while the remaining optional fields carry device tensors
    and metrics.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only ``outputs`` takes part in equality; tensors are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape only, keeping the output readable."""
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_repr = probs.shape if probs is not None else "None"
        token_ids_repr = token_ids.shape if token_ids is not None else "None"
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"sampled_token_ids={token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
799
800


801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
@dataclass
class PoolerOutput:
    """Result of the pooling operation of an embedding model.

    Indexing, item assignment and ``len()`` are forwarded to ``outputs``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only ``outputs`` takes part in equality; metrics are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Collect the sequence ids of every sequence group, in order.

    The ids are taken by iterating each group's ``seq_data`` and are
    concatenated following the order of ``seq_group_metadata_list``.
    """
    all_ids = []
    for group_metadata in seq_group_metadata_list:
        all_ids.extend(group_metadata.seq_data)
    return all_ids


class HiddenStates:
    """Hidden states of sequences that are still in progress.

    Speculative decoding uses this container to hand hidden states from
    the target model to the proposer model in the following step.

    ``seq_ids`` lists the sequence id for each entry along the batch
    dimension of the ``hidden_states`` tensor."""

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.hidden_states: torch.Tensor = hidden_states
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append hidden states produced by a target model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_ids)
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the rows belonging to the given sequence groups."""
        wanted_ids = get_all_seq_ids(seq_group_metadata_list)
        if wanted_ids == self.seq_ids:
            # Batch contents are unchanged, nothing to prune.
            return
        # Look up the current row of every sequence id that is kept.
        kept_rows = list(map(self.seq_ids.index, wanted_ids))
        self.hidden_states = self.hidden_states[kept_rows]
        self.seq_ids = wanted_ids


862
863
@dataclass
class ExecuteModelRequest:
    """A model execution request that carries CPU metadata only.

    The LLM engine is expected to build one instance of this class per
    request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Return a copy of this request bound to a new metadata list.

        The three block lists are shallow-copied; every other field is
        passed through to the new request as-is.
        """
        swap_in = self.blocks_to_swap_in.copy()
        swap_out = self.blocks_to_swap_out.copy()
        to_copy = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=swap_in,
            blocks_to_swap_out=swap_out,
            blocks_to_copy=to_copy,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
        )