input_metadata.py 1.47 KB
Newer Older
1
from typing import Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3

import torch
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        is_prompt: Flag stored as-is on the instance; presumably marks
            whether the inputs are prompts (confirm against callers).
        slot_mapping: The address to write the new KV to of each token.
        max_context_len: The maximum context length. May be None.
        context_lens: The length of attention context for each sequence.
            May be None.
        block_tables: The block tables. (Seq id -> list of physical block)
            May be None.
        use_cuda_graph: Flag stored as-is on the instance; presumably
            tells consumers whether a CUDA graph is in use (confirm
            against callers).
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
    ) -> None:
        # All arguments are stored verbatim; no validation or copying is done.
        self.is_prompt = is_prompt
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph

        # Set during the execution of the first attention op.
        # FIXME(woosuk): This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        """Return a debug string listing every constructor-supplied field."""
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph})")