from typing import List, Optional

import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
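    """Attention layer for OPT under CacheFlow's paged KV-cache layout.

    Prompt tokens are processed by `multi_query_kv_attention` (pure PyTorch,
    causal attention within each prompt), while generation tokens are processed
    by `single_query_cached_kv_attention`, which calls the custom CUDA kernel
    that reads keys and values from the block-structured cache.
    """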

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

    def _masked_attention(
        self,
        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
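        """Plain scaled dot-product attention with an optional additive mask."""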
        query = query * self.scale
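        # 'qhd,khd->hqk': per-head Q @ K^T, giving [num_heads, num_queries, num_keys].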
        attn = torch.einsum('qhd,khd->hqk', query, key)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('hqk,khd->qhd', attn, value)
        return out

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
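        """Causal self-attention over the prompt tokens of each sequence.

        The flattened query/key/value tensors hold the prompts back to back;
        `prompt_lens` gives each prompt's length so the prompts can be sliced
        apart and attended to independently with a causal mask.
        """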
        # FIXME(woosuk): Replace the following with a custom op.
        start_idx = 0
        for prompt_len in prompt_lens:
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]

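            # Build a causal mask: strictly upper-triangular entries get a large
            # negative value so future positions vanish after the softmax.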
            attention_mask = torch.triu(
                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
            attention_out = self._masked_attention(q, k, v, attention_mask)
            out.copy_(attention_out, non_blocking=True)

            start_idx += prompt_len

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
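        """Paged attention for generation tokens (one query token per sequence).

        Delegates to the custom CUDA kernel, which gathers each sequence's keys
        and values from the block-structured cache using `block_tables` and
        `context_lens`. The innermost `x` dimension of the key cache layout is
        presumably a vectorization factor used by the kernel's memory accesses.
        """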
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
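        """Run attention for a batch that mixes prompt and generation tokens.

        Prompt tokens come first in the flattened input, followed by generation
        tokens and then any padding. The new keys/values are written into the
        paged cache before the generation tokens attend over it.
        """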
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        self.multi_query_kv_attention(
            output, query, key, value, input_metadata.prompt_lens)

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            start_idx = sum(input_metadata.prompt_lens)
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)
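

# Minimal usage sketch (hypothetical sizes, not part of the original file). It
# only exercises the pure-PyTorch prompt path; the cached-KV generation path
# additionally needs block tables, slot mappings, a populated InputMetadata,
# and the compiled cacheflow CUDA kernels.
if __name__ == '__main__':
    num_heads, head_size = 2, 4
    prompt_lens = [3, 5]
    num_prompt_tokens = sum(prompt_lens)

    attn = OPTCacheFlowAttention(scale=head_size ** -0.5)
    query = torch.randn(num_prompt_tokens, num_heads, head_size)
    key = torch.randn(num_prompt_tokens, num_heads, head_size)
    value = torch.randn(num_prompt_tokens, num_heads, head_size)
    output = torch.empty_like(query)

    attn.multi_query_kv_attention(output, query, key, value, prompt_lens)
    print(output.shape)  # torch.Size([8, 2, 4])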