from typing import List

import torch


class InputMetadata:
    """Metadata describing one batch of input tokens for the model.

    The batch lays out all prompt (prefill) sequences first, followed by
    the generation (decode) sequences, so
    ``num_prompts + num_generation_tokens == len(seq_ids)``.
    """

    def __init__(
        self,
        seq_ids: List[int],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        # FIXME: Rename
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        """Store the batch layout and derive per-batch counts.

        Args:
            seq_ids: One id per sequence in the batch (prompts first,
                then generation sequences).
            prompt_lens: Token count of each prompt sequence.
            slot_mapping: One cache-slot index per valid token; its
                length defines ``num_valid_tokens``.
            context_lens: One context length per generation sequence;
                its length defines ``num_generation_tokens``.
            max_context_len: Maximum context length in the batch.
            block_tables: Tensor of shape
                (num_generation_sequences, max_num_blocks_per_seq);
                may be empty when no sequence has a block table yet.
        """
        self.seq_ids = seq_ids
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        # Derived counts: one entry in prompt_lens per prompt, one entry
        # in context_lens per generation token, one slot per valid token.
        self.num_prompts = len(prompt_lens)
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = len(slot_mapping)
        # An empty block_tables tensor has no second dimension to read,
        # so treat it as "zero blocks per sequence".
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        # Internal consistency checks on the batch layout.
        assert self.num_generation_tokens == block_tables.shape[0]
        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)

    def __repr__(self) -> str:
        return (f'InputMetadata('
                f'seq_ids={self.seq_ids}, '
                f'num_prompts={self.num_prompts}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'max_context_len={self.max_context_len})')