from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any

from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention

from comfy import model_management

if model_management.xformers_enabled():
    import xformers
    import xformers.ops

from comfy.cli_args import args
import comfy.ops

# CrossAttn precision handling
if args.dont_upcast_attention:
    print("disabling upcasting of attention")
    _ATTN_PRECISION = "fp16"
else:
    _ATTN_PRECISION = "fp32"


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            operations.Linear(dim, inner_dim, dtype=dtype, device=device),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels, dtype=None, device=None):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

def attention_basic(q, k, v, heads, mask=None):
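    # Reference implementation: reshapes (b, tokens, heads*dim_head) to (b*heads, tokens, dim_head),
    # materializes the full softmax(q @ k^T * scale) matrix with einsum, then merges the heads back.
    # Memory use grows quadratically with the token count.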
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out


def attention_sub_quad(query, key, value, heads, mask=None):
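    # Sub-quadratic attention: picks query/key chunk sizes from the currently free device memory
    # and delegates to efficient_dot_product_attention, so the full attention matrix is never
    # materialized at once.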
    b, _, dim_head = query.shape
    dim_head //= heads

    scale = dim_head ** -0.5
    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)

    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = _ATTN_PRECISION == "fp32" and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits // 8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits // 8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
    )

    hidden_states = hidden_states.to(dtype)

    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
    return hidden_states

def attention_split(q, k, v, heads, mask=None):
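    # Split attention: estimates the size of the attention matrix and, if it would not fit in free
    # memory, slices the query dimension into `steps` chunks, doubling the slice count on OOM.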
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    if _ATTN_PRECISION == "fp32":
        element_size = 4
    else:
        element_size = q.element_size()

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1


    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #      f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if _ATTN_PRECISION == "fp32":
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            if first_op_done == False:
                model_management.soft_empty_cache(True)
                if cleared_cache == False:
                    cleared_cache = True
                    print("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                print("out of memory error, increasing steps and trying again", steps)
            else:
                raise e

    del q, k, v

    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1

def attention_xformers(q, k, v, heads, mask=None):
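    # Thin wrapper around xformers' memory_efficient_attention; masks are not supported here and
    # raise NotImplementedError.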
    b, _, dim_head = q.shape
    dim_head //= heads

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # actually compute the attention, what we cannot get enough of
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    if exists(mask):
        raise NotImplementedError
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out

def attention_pytorch(q, k, v, heads, mask=None):
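    # Uses torch.nn.functional.scaled_dot_product_attention, which can dispatch to PyTorch's fused
    # attention kernels and accepts an optional attention mask.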
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = (
        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    )
    return out


optimized_attention = attention_basic
optimized_attention_masked = attention_basic
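# Select the attention backend once at import time: optimized_attention handles the unmasked path,
# while optimized_attention_masked keeps the basic implementation unless PyTorch's fused attention
# is enabled.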

if model_management.xformers_enabled():
    print("Using xformers cross attention")
    optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
    print("Using pytorch cross attention")
    optimized_attention = attention_pytorch
else:
    if args.use_split_cross_attention:
        print("Using split optimization for cross attention")
        optimized_attention = attention_split
    else:
        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
        optimized_attention = attention_sub_quad

if model_management.pytorch_attention_enabled():
    optimized_attention_masked = attention_pytorch

class CrossAttention(nn.Module):
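    # Cross-attention layer: projects x to queries and `context` (or x itself, for self-attention)
    # to keys/values, then runs the globally selected optimized_attention implementation.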
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        if mask is None:
            out = optimized_attention(q, k, v, self.heads)
        else:
            out = optimized_attention_masked(q, k, v, self.heads, mask)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
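    # Self-attention -> cross-attention -> feed-forward block with pre-LayerNorm and residual
    # connections; transformer_options lets callers patch or replace the attention calls.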
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head

    def forward(self, x, context=None, transformer_options={}):
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
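        # Unpack the per-call options (block identifiers, patches, replacement hooks) supplied via
        # transformer_options before running the attention layers.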
        extra_options = {}
        block = None
        block_index = 0
        if "current_index" in transformer_options:
            extra_options["transformer_index"] = transformer_options["current_index"]
        if "block_index" in transformer_options:
            block_index = transformer_options["block_index"]
            extra_options["block_index"] = block_index
        if "original_shape" in transformer_options:
            extra_options["original_shape"] = transformer_options["original_shape"]
        if "block" in transformer_options:
            block = transformer_options["block"]
            extra_options["block"] = block
        if "cond_or_uncond" in transformer_options:
            extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
        if "patches" in transformer_options:
            transformer_patches = transformer_options["patches"]
        else:
            transformer_patches = {}

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head

        if "patches_replace" in transformer_options:
            transformer_patches_replace = transformer_options["patches_replace"]
        else:
            transformer_patches_replace = {}

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        if "attn1_patch" in transformer_patches:
            patch = transformer_patches["attn1_patch"]
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
                x = p(x, extra_options)

        n = self.norm2(x)

        context_attn2 = context
        value_attn2 = None
        if "attn2_patch" in transformer_patches:
            patch = transformer_patches["attn2_patch"]
            value_attn2 = context_attn2
            for p in patch:
                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
        block_attn2 = transformer_block
        if block_attn2 not in attn2_replace_patch:
            block_attn2 = block

        if block_attn2 in attn2_replace_patch:
            if value_attn2 is None:
                value_attn2 = context_attn2
            n = self.attn2.to_q(n)
            context_attn2 = self.attn2.to_k(context_attn2)
            value_attn2 = self.attn2.to_v(value_attn2)
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
                for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim, in_channels,
540
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0, dtype=dtype, device=device)
        else:
            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in