from typing import List, Optional

from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
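    # Covers both phases of serving: prompt tokens go through FlashAttention
    # (multi_query_kv_attention), while generation tokens attend over the
    # block-structured KV cache via a custom kernel
    # (single_query_cached_kv_attention).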

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
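        # All prompts are packed back-to-back along the token dimension, so a
        # single FlashAttention call with causal masking handles every prompt
        # in the batch.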
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[2]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

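        # Build the cumulative prompt-length offsets [0, len_0, len_0 + len_1, ...]
        # that FlashAttention uses (as cu_seqlens) to locate each prompt's
        # boundaries in the packed token dimension.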
        device = query.device
        prefix_sum = [0]
        for prompt_len in prompt_lens:
            prefix_sum.append(prefix_sum[-1] + prompt_len)
        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
        max_prompt_len = max(prompt_lens)

        # FIXME(woosuk): Unnecessary copy. Optimize this.
        qkv = torch.stack([query, key, value], dim=1)
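        # qkv has shape [num_prompt_tokens, 3, num_heads, head_size], the packed
        # layout expected by the FlashAttention module.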
        out = self.flash_attn(
            qkv,
            cu_seqlens=prefix_sum,
            max_s=max_prompt_len,
            causal=True,
        )[0]
        # FIXME(woosuk): Unnecessary copy. Optimize this.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
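        # Each generation token attends over its sequence's cached context; the
        # custom kernel gathers keys/values from the block-structured cache
        # using input_metadata.block_tables and context_lens.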
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
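        # The flattened inputs are laid out as prompt tokens first, then
        # generation tokens, then any padding; the two attention paths below
        # slice out their portion of the batch.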
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        if input_metadata.num_prompts > 0:
            num_prompt_tokens = sum(input_metadata.prompt_lens)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.prompt_lens,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
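        # input_metadata.slot_mapping gives each token's destination slot in the cache.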
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            start_idx = sum(input_metadata.prompt_lens)
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)