from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Optional, List, Any, Dict

import torch
if TYPE_CHECKING:
    from xformers.ops.fmha.attn_bias import AttentionBias

@dataclass
class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """

    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at offset 3 of block 2, offset 2 of
    # block 0, and offset 1 of block 1, respectively (as the example shows:
    # block = slot // block_size, offset = slot % block_size).
    slot_mapping: torch.Tensor
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int

    # Definition of context_len, subquery_len, and seqlen:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|
    #
    # WARNING: context_len has a different definition depending on whether
    # it is prefill or decoding. For prefill, it doesn't include the new
    # tokens. For decoding, it includes the new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum context length in the batch.
    max_context_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). The length of context (tokens stored in KV cache) per
    # sequence. WARNING: When it is a prefill request, it doesn't include new
    # tokens. When it is for decoding, it includes a new token.
    context_lens: Optional[torch.Tensor]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
    # Cuda graph is currently enabled for decoding only (checked in
    # __post_init__).
    use_cuda_graph: bool
    # KV-cache dtype given as a string; how it is interpreted is up to the
    # consumer of this metadata (not visible here).
    kv_cache_dtype: str

    def __post_init__(self):
        """Initialize non-field state and validate the cuda-graph invariant.

        Raises:
            AssertionError: if ``use_cuda_graph`` is set while the batch
                contains prompt tokens.
        """
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when alibi
        # slopes are used, due to a limitation of the xformers API.
        # It is not a dataclass field, so it does not appear in __repr__,
        # __init__, or asdict_zerocopy().
        self.attn_bias: Optional[List["AttentionBias"]] = None

        # Cuda graph is only used for decoding now.
        if self.use_cuda_graph:
            assert self.num_prompt_tokens == 0, (
                "CUDA graph is only supported for decoding, but "
                f"num_prompt_tokens={self.num_prompt_tokens}")

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Returns:
            A dict mapping each dataclass field name to the attribute value
            itself (tensors are shared, not copied). Non-field attributes
            such as ``attn_bias`` are not included.
        """
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }