attention.py 7.43 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
from typing import Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
4

import torch
import torch.nn as nn
5
from xformers import ops as xops
Woosuk Kwon's avatar
Woosuk Kwon committed
6

7
8
from cacheflow import attention_ops
from cacheflow import cache_ops
9
from cacheflow import pos_encoding_ops
10
from cacheflow.model_executor.input_metadata import InputMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12


13
class GPTCacheFlowAttention(nn.Module):
    """Attention layer combining prompt attention with a paged KV cache.

    In ``forward``:
      * prompt tokens are processed with xformers' memory-efficient attention
        (``multi_query_kv_attention``);
      * the keys/values of the valid tokens are written into
        ``key_cache``/``value_cache`` via ``cache_ops.reshape_and_cache``;
      * generation tokens are processed with the custom
        ``attention_ops.single_query_cached_kv_attention`` kernel, which reads
        from those caches.
    """

    def __init__(self, scale: float) -> None:
        """Create the layer.

        Args:
            scale: Attention scaling factor. It is passed as ``scale`` to both
                the xformers op (prompt path) and the cached-KV kernel
                (generation path).
        """
        super().__init__()
        self.scale = float(scale)
        # Passed as ``op=`` so xformers runs the cutlass forward kernel for the
        # prompt path rather than choosing an operator itself.
        self.attn_op = xops.fmha.cutlass.FwOp()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
        """Compute attention over the prompt tokens and store it in ``output``.

        All prompt tokens are packed along the token dimension and sent to
        xformers as a single batch of size 1 (``unsqueeze(0)``). ``attn_bias``
        is forwarded unchanged. Presumably it keeps tokens of different
        prompts from attending to each other; TODO confirm against
        InputMetadata.

        Args:
            output: Destination tensor, written in place.
            query: Prompt-token queries.
            key: Prompt-token keys.
            value: Prompt-token values.
            attn_bias: xformers attention bias for the packed prompts.

        Returns:
            ``output`` (the same tensor object, filled in place).
        """
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
        """Compute attention for generation tokens against the paged KV cache.

        Args:
            output: Destination tensor, written by the kernel.
            query: Generation-token queries.
            key_cache: Paged key cache.
            value_cache: Paged value cache. ``head_size`` and ``block_size``
                are read from its shape.
            input_metadata: Supplies ``block_tables``, ``context_lens`` and
                ``max_context_len`` for the kernel.

        Raises:
            ValueError: If ``head_size`` is not one of the sizes the kernel
                supports.
        """
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        """Run attention for a batch that mixes prompt and generation tokens.

        The slicing below assumes this token layout:
        rows ``[0, num_prompt_tokens)`` are prompt tokens, rows
        ``[num_prompt_tokens, num_valid_tokens)`` are generation tokens, and
        any remaining rows are padding.

        Args:
            query: Queries for all tokens.
            key: Keys for all tokens.
            value: Values for all tokens.
            key_cache: Paged key cache. The new keys are written into it.
            value_cache: Paged value cache. The new values are written into
                it, and ``num_heads``/``head_size`` are read from its shape.
            input_metadata: Token counts, attention bias, slot mapping and
                block tables for this batch.
            cache_event: If not None, it is waited on before the KV cache is
                written.

        Returns:
            The attention output, reshaped to
            ``[num_tokens, num_heads * head_size]``. Padding rows (index
            ``>= num_valid_tokens``) are never written. Because ``output``
            comes from ``torch.empty_like``, those rows hold uninitialized
            values.
        """
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done. The wait happens before the cache
        # is written below; presumably the event guards a pending operation
        # on the same cache tensors. TODO confirm with the caller.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache, at the
        # slots given by input_metadata.slot_mapping. Padding rows are
        # excluded.
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)
133
134


135
136
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        """Create the layer and precompute the rotary cos/sin cache.

        Args:
            scale: Attention scaling factor, forwarded to the base class.
            rotary_dim: Number of dimensions that get rotary frequencies.
            max_position: Number of positions covered by the cache.
            base: Base of the geometric frequency schedule.
        """
        super().__init__(scale)

        # Create the cos and sin cache.
        # NOTE: The exponent is built in float32 explicitly. An integer arange
        # divided by an int yields torch's *default* float dtype. Per the
        # FIXME below, the default dtype may already be the model dtype (e.g.
        # float16) at this point, which would compute inv_freq in reduced
        # precision. The resulting angle errors grow with position.
        inv_freq = 1.0 / (base ** (
            torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        # Non-persistent: the cache is recomputed here rather than loaded
        # from or saved to the state dict.
        self.register_buffer('cos_sin_cache', cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,            # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        """Apply NeoX rotary embedding to query/key, then run base attention.

        The return value of ``pos_encoding_ops.rotary_embedding_neox`` is
        discarded and the same ``query``/``key`` objects are then passed on.
        Presumably the op rotates them in place (TODO confirm), so callers
        may see ``query`` and ``key`` modified.

        Returns:
            The result of ``GPTCacheFlowAttention.forward``.
        """
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        head_size = value_cache.shape[2]
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )