"...composable_kernel_onnx.git" did not exist on "bbcb67d0aac81b51336981713662a726875ebd58"
Commit 1a4bd9e9 authored by comfyanonymous

Refactor the attention functions.

There's no reason for the whole CrossAttention object to be repeated when
only the operation in the middle changes.
parent 8cc75c64
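The consolidation the message describes can be sketched as follows: the q/k/v and output projections live in a single CrossAttention module, and only the core attention computation in the middle is swapped per backend. This is a minimal sketch under assumed names (attention_basic, attn_op) and signatures, not the code from this repository.

# Minimal sketch of the refactor's idea: one shared CrossAttention module,
# with only the attention operation (`attn_op`) changing between backends.
# Names and signatures here are illustrative assumptions.
import torch
import torch.nn as nn
from einops import rearrange


def attention_basic(q, k, v, heads):
    # Plain softmax attention; an xformers or torch-SDP backend would
    # implement the same (q, k, v, heads) -> out contract.
    dim_head = q.shape[-1] // heads
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5
    out = torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v)
    return rearrange(out, '(b h) n d -> b n (h d)', h=heads)


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64,
                 attn_op=attention_basic):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = query_dim if context_dim is None else context_dim
        self.heads = heads
        self.attn_op = attn_op  # the only piece that differs per backend
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        # Mask handling omitted for brevity in this sketch.
        context = x if context is None else context
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        return self.to_out(self.attn_op(q, k, v, self.heads))


# Usage: swap the backend without duplicating the whole module.
attn = CrossAttention(query_dim=320, attn_op=attention_basic)
out = attn(torch.randn(1, 4096, 320))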
@@ -6,7 +6,6 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
-from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
 import comfy.ops
@@ -352,15 +351,6 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         out = self.proj_out(out)
         return x+out
-class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
-    def forward(self, x, context=None, mask=None):
-        b, c, h, w = x.shape
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        out = super().forward(x, context=context, mask=mask)
-        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
-        return x + out
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled_vae() and attn_type == "vanilla":
@@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
         return MemoryEfficientAttnBlock(in_channels)
     elif attn_type == "vanilla-pytorch":
         return MemoryEfficientAttnBlockPytorch(in_channels)
-    elif type == "memory-efficient-cross-attn":
-        attn_kwargs["query_dim"] = in_channels
-        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
     elif attn_type == "none":
         return nn.Identity(in_channels)
     else: