# input_metadata.py
from typing import List, Dict, Tuple

import torch

from cacheflow.sampling_params import SamplingParams


class InputMetadata:
    """Container for the metadata describing one batch of model inputs.

    The batch consists of ``len(prompt_lens)`` prompt sequences plus
    ``context_lens.shape[0]`` generation tokens. The constructor stores every
    argument as an attribute and derives a few summary counts from them
    (``num_prompts``, ``num_prompt_tokens``, ``max_prompt_len``,
    ``num_generation_tokens``, ``num_valid_tokens``,
    ``max_num_blocks_per_seq``).

    Args:
        seq_groups: List of ``(ints, sampling_params)`` pairs; the ints are
            presumably the sequence ids of the group (verify against caller).
        seq_logprobs: Sequence id -> cumulative log-probability.
        prompt_lens: Length of each prompt sequence in the batch.
        cumulative_prompt_lens: Tensor of running prompt-length offsets
            (assumed to be prefix sums of ``prompt_lens`` -- TODO confirm;
            not validated here).
        slot_mapping: Tensor whose first dimension defines
            ``num_valid_tokens``.
        context_lens: Tensor with one entry per generation token; its first
            dimension defines ``num_generation_tokens``.
        max_context_len: Stored as-is; presumably ``max(context_lens)``
            (not checked here).
        block_tables: Tensor with one row per generation token. When it is
            non-empty it is treated as 2-D, and ``shape[1]`` gives
            ``max_num_blocks_per_seq``.

    Raises:
        AssertionError: If ``block_tables`` does not have exactly one row per
            generation token.
    """

    def __init__(
        self,
        # Forward reference: SamplingParams is imported at module level; the
        # string form keeps the annotation from being evaluated eagerly.
        seq_groups: List[Tuple[List[int], 'SamplingParams']],
        seq_logprobs: Dict[int, float],                         # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        cumulative_prompt_lens: torch.Tensor,
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.cumulative_prompt_lens = cumulative_prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        # Derived counts.
        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        # An empty block table may have no second dimension, so guard the
        # shape[1] access.
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0

        # Each generation token needs its own block-table row. (A check that
        # context_lens matches num_generation_tokens would be tautological,
        # since that count is defined from context_lens above.)
        assert block_tables.shape[0] == self.num_generation_tokens, (
            f'block_tables has {block_tables.shape[0]} rows, expected '
            f'{self.num_generation_tokens} (one per generation token)')

    def __repr__(self) -> str:
        return (f'InputMetadata('
                f'num_prompts={self.num_prompts}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'max_prompt_len={self.max_prompt_len}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'max_context_len={self.max_context_len}, '
                f'prompt_lens={self.prompt_lens}, '
                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                f'slot_mapping={self.slot_mapping}, '
                f'context_lens={self.context_lens}, '
                f'block_tables={self.block_tables})')