from typing import List, Optional

from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
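        """Computes causal attention over the prompt tokens with FlashAttention.

        The prompts are packed back-to-back along the token dimension, so the
        per-prompt lengths are converted into cumulative sequence lengths
        (cu_seqlens) for FlashAttention, which applies causal self-attention
        within each prompt independently. The result is written into the
        first sum(prompt_lens) rows of `output`.
        """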
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[2]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        device = query.device
        prefix_sum = [0]
        for prompt_len in prompt_lens:
            prefix_sum.append(prefix_sum[-1] + prompt_len)
        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
        max_prompt_len = max(prompt_lens)

        # FIXME(woosuk): Unnecessary copy. Optimize this.
        qkv = torch.stack([query, key, value], dim=1)
        out = self.flash_attn(
            qkv,
            cu_seqlens=prefix_sum,
            max_s=max_prompt_len,
            causal=True,
        )[0]
        num_tokens = prefix_sum[-1]
        # FIXME(woosuk): Unnecessary copy. Optimize this.
        output[:num_tokens].copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
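        """Computes attention for the generation tokens, one query per sequence.

        Keys and values are read from the paged KV cache through the block
        tables in `input_metadata`; the computation is performed by the custom
        attention_ops.single_query_cached_kv_attention kernel. (The extra `x`
        dimension in the key cache layout is presumably the kernel's
        vectorization width along head_size.)
        """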
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
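        """Runs attention over a batch that mixes prompt and generation tokens.

        Steps:
          1. Drop padding tokens beyond input_metadata.num_valid_tokens.
          2. Reshape the flat [num_tokens, num_heads * head_size] inputs to
             [num_tokens, num_heads, head_size].
          3. Run FlashAttention over the prompt tokens, if any.
          4. After cache_event completes, write the new keys and values into
             the paged KV cache.
          5. Run the single-query cached-KV kernel over the generation tokens,
             if any (they follow the prompt tokens in the batch).
          6. Reshape the output back to [num_tokens, num_heads * head_size];
             it may still contain padding rows.
        """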
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        if input_metadata.num_prompts > 0:
            self.multi_query_kv_attention(
                output, query, key, value, input_metadata.prompt_lens)

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            start_idx = sum(input_metadata.prompt_lens)
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)
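

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). It shows how an
# OPT-style attention layer could wrap OPTCacheFlowAttention. The projection
# layers and the `hidden_size` / `num_heads` parameters are illustrative
# assumptions; only the tensor shapes follow the comments above.
# ---------------------------------------------------------------------------
class _ExampleOPTAttentionLayer(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        assert hidden_size % num_heads == 0
        self.head_size = hidden_size // num_heads
        # Fused QKV projection: [num_tokens, hidden_size] -> [num_tokens, 3 * hidden_size].
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        # Standard 1/sqrt(head_size) softmax scaling.
        self.attn = OPTCacheFlowAttention(scale=self.head_size ** -0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,        # [num_tokens, hidden_size]
        key_cache: torch.Tensor,            # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,          # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                      # [num_tokens, hidden_size]
        # Split the fused projection into query, key, and value.
        query, key, value = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        attn_output = self.attn(
            query, key, value, key_cache, value_cache, input_metadata,
            cache_event)
        return self.out_proj(attn_output)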