from typing import Optional

import torch
import torch.nn as nn

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
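    """Attention for OPT on top of CacheFlow's paged KV cache.

    Prompt tokens go through multi_query_kv_attention; generation tokens
    attend over the block-structured key/value caches via
    single_query_cached_kv_attention.
    """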

    def __init__(self, scale: float) -> None:
        super().__init__()
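        # Softmax scaling factor applied to the query in _masked_attention
        # (typically 1.0 / sqrt(head_size), chosen by the caller).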
        self.scale = scale

    def _masked_attention(
        self,
        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
        query = query * self.scale
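        # Per-head attention scores: attn[h, q, k] = sum_d query[q, h, d] * key[k, h, d].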
        attn = torch.einsum('qhd,khd->hqk', query, key)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = torch.softmax(attn, dim=-1)
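        # Weighted sum of values per head: out[q, h, d] = sum_k attn[h, q, k] * value[k, h, d].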
        out = torch.einsum('hqk,khd->qhd', attn, value)
        return out

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,   # [prompt_len, num_heads, head_size]
        query: torch.Tensor,    # [prompt_len, num_heads, head_size]
        key: torch.Tensor,      # [prompt_len, num_heads, head_size]
        value: torch.Tensor,    # [prompt_len, num_heads, head_size]
    ) -> None:
        # FIXME(woosuk): Replace this with a custom op call.
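        # Causal mask: a large negative bias on positions k > q so each query
        # only attends to itself and earlier tokens.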
        attention_mask = torch.triu(
            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
        out = self._masked_attention(query, key, value, attention_mask)
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size // x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
        input_metadata: InputMetadata,
    ) -> None:
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i].unsqueeze(0)
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            keys = []
            values = []
            for j in range(context_len):
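                # The block table maps the logical token index j to a physical
                # cache block; the offset selects the slot within that block.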
                block_number = int(block_table[j // block_size])
                block_offset = j % block_size

                k = key_cache[block_number, :, :, block_offset, :]
                k = k.reshape(num_heads, head_size)
                keys.append(k)

                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            keys = torch.stack(keys, dim=0)
            values = torch.stack(values, dim=0)

            out = self._masked_attention(q, keys, values)
            out = out.view(num_heads, head_size)
            output[i].copy_(out, non_blocking=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
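        # NOTE: query/key/value come in flattened as [num_tokens, num_heads * head_size]
        # and are reshaped to [num_tokens, num_heads, head_size] below.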
        # Prune out invalid tokens.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)