from typing import List, Optional

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        prompt_lens: Lengths of prompts. An empty list makes ``is_prompt``
            False.
        slot_mapping: For each token, the slot address to which its new KV
            is written.
        max_context_len: The maximum context length. May be None.
        context_lens: The length of the attention context for each
            sequence. May be None.
        block_tables: The block tables (seq id -> list of physical blocks).
            May be None.

    Attributes:
        is_prompt: True iff ``prompt_lens`` is non-empty.
        attn_bias: Starts as None; set during the execution of the first
            attention op.
    """

    def __init__(
        self,
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
    ) -> None:
        self.prompt_lens = prompt_lens
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables

        # The batch is treated as a prompt batch whenever any prompt
        # lengths were supplied.
        self.is_prompt = len(prompt_lens) > 0

        # Set during the execution of the first attention op.
        # FIXME(woosuk): This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        return ("InputMetadata("
                f"prompt_lens={self.prompt_lens}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables})")