# input_metadata.py
from dataclasses import dataclass, fields
from typing import Optional, Any, Dict
import torch


@dataclass
class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        is_prompt: Whether this metadata describes a prompt batch
            (presumably the prefill phase rather than decoding).
        slot_mapping: For each token, the address (slot) that its new KV
            entry is written to.
        prompt_lens: Lengths of prompts, or None.
        max_seq_len: The maximum sequence length, or None.
        start_loc: Start locations of the sequences, or None (presumably
            offsets into the flattened token batch; confirm with callers).
        max_context_len: The maximum context length, or None.
        context_lens: The length of the attention context for each
            sequence, or None.
        block_tables: The block tables (seq id -> list of physical blocks),
            or None.
        use_cuda_graph: Whether the CUDA-graph execution path is used.
        kv_cache_dtype: Data type used to store the KV cache.

    Attributes:
        attn_bias: Set to None by ``__post_init__``. It is a plain instance
            attribute rather than a dataclass field, so it is not an
            ``__init__`` parameter, is omitted from ``__repr__`` and
            ``__eq__``, and is not returned by :meth:`asdict_zerocopy`.
    """

    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[torch.Tensor]
    max_seq_len: Optional[int]
    start_loc: Optional[torch.Tensor]
    max_context_len: Optional[int]
    context_lens: Optional[torch.Tensor]
    block_tables: Optional[torch.Tensor]
    use_cuda_graph: bool
    kv_cache_dtype: str

    def __post_init__(self) -> None:
        # Assigned here instead of being declared as a field, so it stays
        # out of __init__, __repr__, __eq__ and fields(). As a result,
        # InputMetadata(**obj.asdict_zerocopy()) works without it, and the
        # rebuilt instance starts with attn_bias reset to None.
        self.attn_bias = None

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Return a shallow ``{field name: value}`` dict of this instance.

        Similar to :func:`dataclasses.asdict`, but the values are the
        instance's own objects rather than deep copies, so tensors are not
        copied. Every field is an ``__init__`` parameter, so the result can
        be passed back as ``InputMetadata(**d)``. ``attn_bias`` is not a
        field and is therefore not included.
        """
        # dataclasses.asdict recurses into nested dataclasses; this does
        # not. If a dataclass is ever added as a field, it will need
        # similar (non-deepcopying) handling.
        return {f.name: getattr(self, f.name) for f in fields(self)}