"...git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "c1476851640042557813c92139ef62e3cd01af56"
Commit faa57430 authored by comfyanonymous's avatar comfyanonymous
Browse files

Controlnet union model basic implementation.

This is only the model code itself, it currently defaults to an empty
embedding [0] * 6 which seems to work better than treating it like a
regular controlnet.

TODO: Add nodes to select the image type.
parent bb663bcd
...@@ -13,7 +13,47 @@ from ..ldm.modules.diffusionmodules.util import ( ...@@ -13,7 +13,47 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists from ..ldm.util import exists
from ..ldm.cascade.common import OptimizedAttention
from collections import OrderedDict
import comfy.ops import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.c = c
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.in_proj(x)
q, k, v = x.split(self.c, dim=2)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResBlockUnionControlnet(nn.Module):
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
self.mlp = nn.Sequential(
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
def attention(self, x: torch.Tensor):
return self.attn(x)
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class ControlledUnetModel(UNetModel): class ControlledUnetModel(UNetModel):
#implemented in the ldm unet #implemented in the ldm unet
...@@ -53,6 +93,7 @@ class ControlNet(nn.Module): ...@@ -53,6 +93,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None, transformer_depth_middle=None,
transformer_depth_output=None, transformer_depth_output=None,
attn_precision=None, attn_precision=None,
union_controlnet=False,
device=None, device=None,
operations=comfy.ops.disable_weight_init, operations=comfy.ops.disable_weight_init,
**kwargs, **kwargs,
...@@ -280,6 +321,65 @@ class ControlNet(nn.Module): ...@@ -280,6 +321,65 @@ class ControlNet(nn.Module):
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch self._feature_size += ch
if union_controlnet:
self.num_control_type = 6
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
# task_scale_factor = num_trans_channel ** 0.5
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
#-----------------------------------------------------------------------------------------------------
control_add_embed_dim = 256
class ControlAddEmbedding(nn.Module):
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
super().__init__()
self.num_control_type = num_control_type
self.in_dim = in_dim
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
def forward(self, control_type, dtype, device):
c_type = torch.zeros((self.num_control_type,), device=device)
c_type[control_type] = 1.0
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
else:
self.task_embedding = None
self.control_add_embedding = None
def union_controlnet_merge(self, hint, control_type, emb, context):
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
inputs = []
condition_list = []
for idx in range(min(1, len(control_type))):
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
if idx < len(control_type):
feat_seq += self.task_embedding[control_type[idx]]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(controlnet_cond)
x = torch.cat(inputs, dim=1)
x = self.transformer_layes(x)
controlnet_cond_fuser = None
for idx in range(len(control_type)):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
o = condition_list[idx] + alpha
if controlnet_cond_fuser is None:
controlnet_cond_fuser = o
else:
controlnet_cond_fuser += o
return controlnet_cond_fuser
def make_zero_conv(self, channels, operations=None, dtype=None, device=None): def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
...@@ -287,7 +387,18 @@ class ControlNet(nn.Module): ...@@ -287,7 +387,18 @@ class ControlNet(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context) guided_hint = None
if self.control_add_embedding is not None:
control_type = kwargs.get("control_type", [])
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)
out_output = [] out_output = []
out_middle = [] out_middle = []
......
...@@ -413,6 +413,12 @@ def load_controlnet(ckpt_path, model=None): ...@@ -413,6 +413,12 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data: if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k) new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
controlnet_config["union_controlnet"] = True
for k in list(controlnet_data.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
new_sd[new_k] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys() leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0: if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys)) logging.warning("leftover keys: {}".format(leftover_keys))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment