"magic_pdf/vscode:/vscode.git/clone" did not exist on "12bec17eed726a9f193cbb1f3512df73582a0b82"
from typing import Optional

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
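        """Computes attention over all prompt tokens in one call.

        Uses xFormers memory-efficient attention with a causal
        (lower-triangular) mask and copies the result into `output` in place.
        """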
        out = xops.memory_efficient_attention(
            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
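        """Computes attention for generation tokens against the paged KV cache.

        Each query row is a single new token; its keys and values are gathered
        from `key_cache`/`value_cache` through the per-sequence block table.
        This is a slow reference loop (see the FIXME below).
        """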
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i].unsqueeze(0)
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

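            # Gather this token's cached keys, one context position at a time.
            # The indexing below implies a key cache layout of
            # [num_blocks, num_heads, head_size // x, block_size, x] for some
            # chunk size x; the reshape flattens each key back to
            # [num_heads, head_size].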
            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.reshape(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=0)

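            # Gather the cached values. value_cache is laid out as
            # [num_blocks, num_heads, block_size, head_size] (see the shape
            # reads above), so one index yields a [num_heads, head_size] value.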
            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=0)

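            # Add a batch dimension for xFormers. No causal mask is needed
            # here: the single query token attends to every cached position
            # in its context.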
            q = q.unsqueeze(0)
            keys = keys.unsqueeze(0)
            values = values.unsqueeze(0)
            out = xops.memory_efficient_attention(
                q, keys, values, scale=self.scale)
            out = out.view(num_heads, head_size)
            output[i].copy_(out, non_blocking=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
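        """Computes attention for a batch of prompt and generation tokens.

        Prompt tokens are handled by multi_query_kv_attention, the new keys
        and values are written into the paged KV cache, and generation tokens
        are handled by single_query_cached_kv_attention. Returns the output
        reshaped to [num_tokens, num_heads * head_size].
        """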
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)