from typing import Optional

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
    """Attention layer for OPT with a block-based (paged) KV cache.

    Prompt tokens are handled by xformers' memory-efficient attention with a
    causal mask, while generation tokens attend over keys and values gathered
    from the paged KV cache via per-sequence block tables.
    """

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic causal (lower-triangular) attention mask.
        self.attention_mask = xops.LowerTriangularMask()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        # query/key/value: [num_prompt_tokens, num_heads, head_size] for one
        # prompt. Add a batch dimension to match xformers' expected
        # [batch, seq_len, num_heads, head_size] layout.
        out = xops.memory_efficient_attention(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=self.attention_mask,
            scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out.squeeze(0), non_blocking=True)
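
    # The helper below is an illustrative sketch, not used at runtime: a naive
    # PyTorch spelling of the masked attention computed above, written out to
    # document the intended math (scaled dot-product attention with a causal
    # mask) for tensors in [seq_len, num_heads, head_size] layout. The name
    # `_naive_causal_attention` is introduced here purely for illustration.
    @staticmethod
    def _naive_causal_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: float,
    ) -> torch.Tensor:
        seq_len = query.shape[0]
        # Causal mask: position i may attend only to positions j <= i.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'),
                       dtype=query.dtype, device=query.device),
            diagonal=1)
        # scores: [num_heads, seq_len, seq_len]
        scores = scale * torch.einsum('qhd,khd->hqk', query, key) + mask
        probs = torch.softmax(scores, dim=-1)
        # Back to [seq_len, num_heads, head_size].
        return torch.einsum('hqk,khd->qhd', probs, value)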

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        # value_cache: [num_blocks, num_heads, block_size, head_size]
        # key_cache:   [num_blocks, num_heads, head_size // x, block_size, x],
        # where the last dim packs x consecutive elements of each key head.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            # Gather the keys of the context tokens from the paged cache.
            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)
            # keys: [num_heads, head_size, context_len]

            # Scaled dot-product attention scores for the single query token.
            logits = self.scale * (q.unsqueeze(1) @ keys)
            # logits: [num_heads, 1, context_len]
            attention_weights = torch.softmax(logits, dim=-1)

            # Gather the values and compute the weighted sum.
            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=1)
            # values: [num_heads, context_len, head_size]

            out = (attention_weights @ values).squeeze(1)
            output[i].copy_(out, non_blocking=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors to [num_tokens, num_heads, head_size].
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
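

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module above): how a logical token
# position is resolved against a block table, mirroring the indexing done in
# single_query_cached_kv_attention. The block size and table contents below
# are hypothetical toy values chosen only to show the arithmetic.
# ---------------------------------------------------------------------------
def _demo_block_table_lookup() -> None:
    block_size = 4
    # Per-sequence block table: logical block index -> physical block number.
    block_table = torch.tensor([7, 2, 5])
    context_len = 10
    for j in range(context_len):
        block_number = int(block_table[j // block_size])  # physical block
        block_offset = j % block_size                      # slot within block
        print(f'token {j} -> block {block_number}, offset {block_offset}')


if __name__ == '__main__':
    _demo_block_table_lookup()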