# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   MIT
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
#   "Self-attention Does Not Need O(n²) Memory":
#   https://arxiv.org/abs/2112.05682v2
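#
# high-level idea (as implemented below): queries are processed in chunks, and for each
# query chunk the key/value tokens are processed in chunks as well; every key/value chunk
# contributes (exp_values, exp_weights_sum, max_score) statistics computed against its own
# local max, and the chunks are then rescaled to a shared global max and summed, which
# reproduces exact softmax attention while only ever materializing chunk-sized score matrices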

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
import logging

try:
    from typing import Optional, NamedTuple, List, Protocol
except ImportError:
    from typing import Optional, NamedTuple, List
    from typing_extensions import Protocol

from comfy import model_management

def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]

class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor

class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...

class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...

def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_weights = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_weights = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )
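    # subtract each chunk's max score before exponentiating, for numerical stability;
    # the shift is compensated for when chunks are merged in _query_chunk_attention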
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    attn_weights -= max_score
    if mask is not None:
        attn_weights += mask
    torch.exp(attn_weights, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)

def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]

        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk
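    # each chunk was exponentiated against its own local max; rescale every chunk by
    # exp(chunk_max - global_max) so they share a common reference, then the summed
    # numerators (values) and denominators (weights) below recover the exact softmax result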

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights

# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_scores = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )

    if mask is not None:
        attn_scores += mask
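
    # try the regular softmax first; on OOM, fall back to a slower variant that reuses the
    # attn_scores buffer in place (subtract max, exponentiate, normalize) instead of
    # allocating a second tensor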
    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in-place softmax instead")
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice

class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk

def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
    mask = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
      This is efficient version of attention presented in
      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
      Args:
        query: queries for calculating attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
          `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
      Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
      """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
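    # the sqrt(k_tokens) default above is what gives this implementation the O(sqrt(n))
    # memory behaviour described in the paper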

    if mask is not None and len(mask.shape) == 2:
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale,
        upcast_attention=upcast_attention
    ) if k_tokens <= kv_chunk_size else (
        # kv-chunked path: more than one key/value chunk per query chunk, so accumulate and
        # merge per-chunk softmax statistics (the branch above is the single-kv-chunk fast path,
        # i.e. plain sliced attention)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )
    
    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size)
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
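

# a minimal usage sketch, not part of the original module: shapes, chunk sizes and the
# comparison below are illustrative assumptions, and running it requires the module's
# `comfy` dependency to be importable. it checks the chunked implementation against a
# plain softmax-attention reference (note that the key is passed pre-transposed as `key_t`).
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_x_heads, tokens, channels_per_head = 2, 1024, 64
    query = torch.randn(batch_x_heads, tokens, channels_per_head)
    key_t = torch.randn(batch_x_heads, channels_per_head, tokens)
    value = torch.randn(batch_x_heads, tokens, channels_per_head)

    # chunk queries 256 at a time; key/value chunk size defaults to sqrt(1024) = 32
    out = efficient_dot_product_attention(
        query, key_t, value,
        query_chunk_size=256,
        use_checkpoint=False,
    )

    # dense reference: softmax(Q @ K^T * scale) @ V
    scale = channels_per_head ** -0.5
    reference = torch.softmax(torch.bmm(query, key_t) * scale, dim=-1).bmm(value)
    print("max abs difference vs dense attention:", (out - reference).abs().max().item())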