# input_metadata.py
from typing import List

import torch


class InputMetadata:
    """Container for sequence metadata plus a few counts derived from it.

    The constructor arguments are stored unchanged as attributes of the same
    name. Three derived attributes are computed from them:

    * ``num_prompts``: ``len(prompt_lens)``.
    * ``num_generation_tokens``: size of the first dimension of
      ``context_lens``.
    * ``max_num_blocks_per_seq``: size of the second dimension of
      ``block_tables``, or ``0`` when ``block_tables`` holds no elements.

    NOTE(review): presumably one instance describes a single batch in which
    prompt sequences and generation sequences are mixed -- confirm against
    the callers.

    Args:
        seq_ids: One id per sequence; its length must equal
            ``num_prompts + num_generation_tokens``.
        prompt_lens: One length per prompt.
        slot_mapping: Stored as-is; not inspected here.
        context_lens: Tensor with one entry per generation token along its
            first dimension.
        max_context_len: Stored as-is; not inspected here.
        block_tables: Tensor whose first dimension must match
            ``num_generation_tokens``. An empty tensor (of any rank) is
            accepted when there are no generation tokens.

    Raises:
        AssertionError: If the sizes of the arguments are inconsistent with
            each other (see the two invariants at the end of ``__init__``).
    """

    def __init__(
        self,
        seq_ids: List[int],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_ids = seq_ids
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        self.num_prompts = len(prompt_lens)
        self.num_generation_tokens = context_lens.shape[0]
        # An empty block table built from an empty list is 1-D with shape
        # (0,), so indexing shape[1] unconditionally would raise IndexError.
        # For an empty 2-D table of shape (n, 0) the result is 0 either way.
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0

        # Internal consistency checks on the sizes of the inputs.
        assert self.num_generation_tokens == block_tables.shape[0], (
            f"block_tables has {block_tables.shape[0]} rows but context_lens "
            f"has {self.num_generation_tokens} entries")
        assert self.num_prompts + self.num_generation_tokens == len(seq_ids), (
            f"expected {self.num_prompts + self.num_generation_tokens} "
            f"seq_ids, got {len(seq_ids)}")

    def __repr__(self) -> str:
        # Only the small, list/scalar fields are shown; tensors are omitted
        # to keep the output short.
        return (f"{type(self).__name__}("
                f"num_prompts={self.num_prompts}, "
                f"num_generation_tokens={self.num_generation_tokens}, "
                f"max_context_len={self.max_context_len}, "
                f"max_num_blocks_per_seq={self.max_num_blocks_per_seq}, "
                f"seq_ids={self.seq_ids}, "
                f"prompt_lens={self.prompt_lens})")