"""Input metadata used by PagedAttention."""
from typing import List, Optional
import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        prompt_lens: Lengths of prompts.
        slot_mapping: The address to write the new KV to of each token.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical block)
        use_cuda_graph: Flag recording whether CUDA graph execution is used
            for this input; stored unchanged (semantics defined by callers).

    Attributes (in addition to the constructor arguments):
        is_prompt: True when ``prompt_lens`` is non-empty.
        attn_bias: Initially ``None``; set during the execution of the first
            attention op.
    """

    def __init__(
        self,
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
    ) -> None:
        self.prompt_lens = prompt_lens
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph

        # A non-empty prompt_lens list marks this batch as a prompt step.
        self.is_prompt = len(prompt_lens) > 0

        # Set during the execution of the first attention op.
        # FIXME(woosuk): This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        return ("InputMetadata("
                f"prompt_lens={self.prompt_lens}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph})")