sequence.py 22.7 KB
Newer Older
1
"""Sequence and its related classes."""
2
import copy
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Dict, List, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.block import LogicalTokenBlock
8
from vllm.lora.request import LoRARequest
9
from vllm.sampling_params import SamplingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
if TYPE_CHECKING:
    import torch
13

14
15
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

16
17
18

@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of chosen token
        rank: The vocab rank of chosen token (>=1)
        decoded_token: The decoded chosen token string
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Logprob tables for a prompt: a list whose entries are either None or a
# dict mapping token id -> Logprob.
# NOTE(review): presumably one entry per prompt position, with None where no
# logprob is available -- confirm against the sampler.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Logprob tables for sampled output: a list of dicts mapping
# token id -> Logprob.
SampleLogprobs = List[Dict[int, Logprob]]
33

Woosuk Kwon's avatar
Woosuk Kwon committed
34
35

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True if `status` is one of the FINISHED_* states."""
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the FINISHED_* states, and None
            for any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
Woosuk Kwon's avatar
Woosuk Kwon committed
70

71

72
73
74
75
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The time of the most recent token-latency
            measurement for the request.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


91
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt tokens followed by the
        output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens

    def reset_num_computed_tokens(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        """Return the prompt token list itself (not a copy)."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        """Return the output token list itself (not a copy)."""
        return self.output_token_ids

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
173
174


Woosuk Kwon's avatar
Woosuk Kwon committed
175
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any. Stored on the
            sequence unchanged.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID of the request, or 0 if there is no LoRA."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash all tokens up to the end of block `logical_idx` together
        with the LoRA integer ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Return the number of tokens covered by blocks 0..logical_idx."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self) -> None:
        """Reset the sequence states for recomputation."""
        self.data.reset_num_computed_tokens()

    def _append_logical_block(self) -> None:
        """Append a new, empty logical block numbered after the last one."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Fill the logical blocks with `token_ids`, allocating new blocks
        whenever the last one is full."""
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            # The slice may be shorter than `num_empty_slots` on the final
            # iteration; the cursor then simply overshoots and ends the loop.
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append one generated token along with its logprob table.

        `logprobs` must contain an entry for `token_id`.
        """
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying `new_seq_id`."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
Woosuk Kwon's avatar
Woosuk Kwon committed
327

Woosuk Kwon's avatar
Woosuk Kwon committed
328

Nick Hill's avatar
Nick Hill committed
329
330
331
332
333
334
335
336
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling. The annotation is a forward
    # reference because this file only imports torch under TYPE_CHECKING.
    generator: Optional["torch.Generator"] = None


337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class MultiModalData:
    """Container for the multi-modal payload of a request.

    Args:
        type: Which kind of data this is (a member of `MultiModalData.Type`).
        data: The payload itself. Its required shape and semantic meaning
            depend on the vision language config of the hosted model; see
            `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        # Both arguments are stored unchanged.
        self.type, self.data = type, data


Woosuk Kwon's avatar
Woosuk Kwon committed
356
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # `last_token_time` starts at the arrival time so the first latency
        # returned by `get_last_latency` is measured from arrival.
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID of the request, or 0 if there is no LoRA."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Gets last token latency for Request level timings."""
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status equals `status`
        when one is given."""
        return list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        # All sequences in the group should have the same prompt, so the
        # number of unfinished prefill tokens are the same across all
        # sequences.
        return list(
            self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add `seq` to the group.

        Raises:
            ValueError: If a sequence with the same ID is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id` from the group.

        Raises:
            ValueError: If no such sequence is in the group.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
504
505


506
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        token_chunk_size: The number of tokens to be processed. None if
            chunking is not required.
        lora_request: LoRA request.
        computed_block_nums: Stored unchanged on the instance. Presumably
            the block numbers that are already computed; confirm with the
            caller.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self._token_chunk_size = token_chunk_size

        # Without an explicit chunk size, a prompt-stage group processes the
        # full length of its first sequence; otherwise exactly one token.
        if self._token_chunk_size is None:
            if is_prompt:
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID of the request, or 0 if there is no LoRA."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size

562

Zhuohan Li's avatar
Zhuohan Li committed
563
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        Returns `NotImplemented` for objects of another type so that Python
        can try the reflected comparison and fall back to its default,
        instead of raising from a plain `==`.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal
596
597


Zhuohan Li's avatar
Zhuohan Li committed
598
599
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: The sampled outputs for the group.
        prompt_logprobs: The logprobs of the prompt, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        Returns `NotImplemented` for objects of another type so that Python
        can try the reflected comparison and fall back to its default,
        instead of raising from a plain `==`.
        """
        if not isinstance(other, SequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)

619

620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
@dataclass
class SamplerOutput:
    """Sampler result holding one SequenceGroupOutput per sequence group.

    Each SequenceGroupOutput carries the candidate next tokens for its
    group. Indexing, item assignment and `len()` are forwarded to
    `outputs`, so an instance can be used like a list; the remaining fields
    are optional device tensors and metrics.
    """

    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        # Forward indexing to the wrapped list.
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        # Forward item assignment to the wrapped list.
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only `outputs` takes part in equality; the optional tensor and
        # metrics fields are ignored.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs