from typing import Optional

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
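    """OPT attention layer backed by CacheFlow's paged KV cache.

    Prompt tokens are handled by a fused causal attention over the raw
    key/value tensors; generation tokens attend to keys and values gathered
    from the block-based cache.
    """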

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
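        """Causal attention over all query tokens of a prompt."""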
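        # Add a batch dimension of 1; xformers expects batched
        # [batch, seq_len, num_heads, head_size] inputs.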
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)
        out = xops.memory_efficient_attention(
            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
        out = out.squeeze(0)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
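        """Attend one generation token per sequence over the paged KV cache.

        Cache layouts, as inferred from the indexing below:
            key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
            value_cache: [num_blocks, num_heads, block_size, head_size]
        """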
        num_heads = value_cache.shape[1]
        block_size = value_cache.shape[2]
        head_size = value_cache.shape[3]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i].unsqueeze(0)
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

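            # Gather this sequence's cached keys and values by walking its
            # block table one token at a time.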
            keys = []
            values = []
            for j in range(context_len):
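                # Map the logical token index to a physical block and offset.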
                block_number = int(block_table[j // block_size])
                block_offset = j % block_size

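                # head_size is split across the key cache's 3rd and 5th dims;
                # slice out this token's key and restore [num_heads, head_size].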
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.reshape(num_heads, head_size)
                keys.append(k)

                v = value_cache[block_number, :, block_offset, :]
                values.append(v)

            keys = torch.stack(keys, dim=0)
            values = torch.stack(values, dim=0)

            # Add the batch/sequence dims xformers expects:
            # q -> [1, 1, num_heads, head_size],
            # keys/values -> [1, context_len, num_heads, head_size].
            q = q.unsqueeze(0)
            keys = keys.unsqueeze(0)
            values = values.unsqueeze(0)
            out = xops.memory_efficient_attention(
                q, keys, values, scale=self.scale)
            out = out.view(num_heads, head_size)
            output[i].copy_(out, non_blocking=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
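        """Compute attention for prompt and generation tokens.

        Prompt tokens go through multi_query_kv_attention; generation tokens
        go through single_query_cached_kv_attention after the new keys and
        values have been written into the cache.
        """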
        # Prune out invalid tokens.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
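

if __name__ == "__main__":
    # Minimal smoke test for the two attention paths above; a sketch, not
    # part of the CacheFlow API. Assumptions: a CUDA device with xformers
    # installed, illustrative shapes, and a SimpleNamespace standing in for
    # InputMetadata (only the fields read by single_query_cached_kv_attention
    # are populated).
    from types import SimpleNamespace

    torch.manual_seed(0)
    device = "cuda"
    num_heads, head_size, block_size = 4, 32, 8
    x = 8  # Assumed vectorization factor of the key cache layout.
    num_blocks = 16
    attn = OPTCacheFlowAttention(scale=head_size ** -0.5)

    # Prompt path: 7 tokens attending causally to one another.
    prompt_len = 7
    q = torch.randn(prompt_len, num_heads, head_size,
                    device=device, dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.empty_like(q)
    attn.multi_query_kv_attention(out, q, k, v)
    print("prompt output:", tuple(out.shape))

    # Generation path: a single query token reading a paged KV cache.
    key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size,
                            x, device=device, dtype=torch.float16)
    value_cache = torch.randn(num_blocks, num_heads, block_size, head_size,
                              device=device, dtype=torch.float16)
    metadata = SimpleNamespace(
        num_generation_tokens=1,
        context_lens=torch.tensor([5]),
        block_tables=torch.tensor([[0]]),
    )
    gen_q = torch.randn(1, num_heads, head_size,
                        device=device, dtype=torch.float16)
    gen_out = torch.empty_like(gen_q)
    attn.single_query_cached_kv_attention(
        gen_out, gen_q, key_cache, value_cache, metadata)
    print("generation output:", tuple(gen_out.shape))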