from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    All ten fields are required and are accepted by the generated
    ``__init__`` in the order they are declared below.

    Args:
        is_prompt: Boolean flag supplied by the caller.
            NOTE(review): presumably True when the batch holds prompt
            tokens rather than generation tokens -- confirm with callers.
        slot_mapping: The address to write the new KV to of each token.
        prompt_lens: Lengths of prompts.
        max_seq_len: Optional integer.
            NOTE(review): presumably the longest sequence length in the
            batch -- confirm with callers.
        start_loc: Optional tensor.
            NOTE(review): presumably the start offset of each sequence in
            the flattened batch -- confirm with callers.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical block)
        use_cuda_graph: Boolean flag supplied by the caller.
            NOTE(review): presumably selects CUDA-graph execution -- confirm
            with callers.
        kv_cache_dtype: Data type to store kv cache.

    Attributes:
        attn_bias: Set to ``None`` by ``__post_init__``. It is a plain
            instance attribute rather than a dataclass field, so it is not
            a constructor parameter and is not shown by ``__repr__``.
    """

    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[torch.Tensor]
    max_seq_len: Optional[int]
    start_loc: Optional[torch.Tensor]
    max_context_len: Optional[int]
    context_lens: Optional[torch.Tensor]
    block_tables: Optional[torch.Tensor]
    use_cuda_graph: bool
    kv_cache_dtype: str

    def __post_init__(self) -> None:
        """Initialise ``attn_bias`` to ``None`` after the generated ``__init__``."""
        # Assigned here instead of being declared as a field so that it will
        # not appear in the __repr__ and __init__.
        self.attn_bias = None