"deploy/vscode:/vscode.git/clone" did not exist on "e99565c7f35d495c07db3b08a0d06c44facd76e9"
Commit f8f7568d authored by comfyanonymous

Basic SD3 controlnet implementation.

Still missing the node to properly use it.
parent 66aaa140
comfy/cldm/mmdit.py (new file):

import torch
from typing import Dict, Optional
import comfy.ldm.modules.diffusionmodules.mmdit
import comfy.latent_formats


class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    def __init__(
        self,
        num_blocks = None,
        dtype = None,
        device = None,
        operations = None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
        # controlnet_blocks
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.joint_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))

        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            self.in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.latent_format = comfy.latent_formats.SD3()

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> torch.Tensor:

        #weird sd3 controlnet specific stuff
        hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)

        y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)
            c = c + y

        if context is not None:
            context = self.context_embedder(context)

        if self.register_length > 0:
            context = torch.cat(
                (
                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                    default(context, torch.Tensor([]).type_as(x)),
                ),
                1,
            )

        output = []

        blocks = len(self.joint_blocks)
        for i in range(blocks):
            context, x = self.joint_blocks[i](
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )

            out = self.controlnet_blocks[i](x)
            count = self.depth // blocks
            if i == blocks - 1:
                count -= 1
            for j in range(count):
                output.append(out)

        return {"output": output}
comfy/controlnet.py:

@@ -11,6 +11,7 @@ import comfy.ops
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
+import comfy.cldm.mmdit

 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -94,13 +95,17 @@ class ControlBase:
         for key in control:
             control_output = control[key]
+            applied_to = set()
             for i in range(len(control_output)):
                 x = control_output[i]
                 if x is not None:
                     if self.global_average_pooling:
                         x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

-                    x *= self.strength
+                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
+                        applied_to.add(x)
+                        x *= self.strength

                     if x.dtype != output_dtype:
                         x = x.to(output_dtype)
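The applied_to set exists because the SD3 controlnet above puts the same tensor object into its output list several times; applying the in-place strength multiply to every list entry would scale that shared tensor repeatedly. A minimal standalone sketch of the idea (names and values are illustrative):

import torch

strength = 0.5
shared = torch.ones(1)                        # one tensor object referenced three times,
control_output = [shared, shared, shared]     # like the repeated SD3 controlnet residuals

applied_to = set()                            # tensor membership here is by object identity
for x in control_output:
    if x not in applied_to:                   # scale each distinct tensor object only once
        applied_to.add(x)
        x *= strength                         # in-place, so every alias sees the result

print(control_output[0])                      # tensor([0.5]), not 0.125 (= 1 * 0.5**3)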
@@ -120,17 +125,18 @@ class ControlBase:
                             if o[i].shape[0] < prev_val.shape[0]:
                                 o[i] = prev_val + o[i]
                             else:
-                                o[i] += prev_val
+                                o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

+        self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -308,6 +314,37 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

+def load_controlnet_mmdit(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    supported_inference_dtypes = model_config.supported_inference_dtypes
+
+    controlnet_config = model_config.unet_config
+    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+    load_device = comfy.model_management.get_torch_device()
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    if manual_cast_dtype is not None:
+        operations = comfy.ops.manual_cast
+    else:
+        operations = comfy.ops.disable_weight_init
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
+    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logging.debug("unexpected controlnet keys: {}".format(unexpected))
+
+    control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:

@@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None):
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
+    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+        return load_controlnet_mmdit(controlnet_data)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
...
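For reference, a hedged sketch of how this dispatch path would be exercised once the missing node exists; the checkpoint path is hypothetical, and only the "controlnet_blocks.0.weight" check above decides the route:

import comfy.controlnet

# Hypothetical SD3 controlnet checkpoint in diffusers format.
ckpt_path = "models/controlnet/sd3_controlnet_canny.safetensors"

# load_controlnet() loads the state dict; because it contains
# "controlnet_blocks.0.weight", the new elif branch hands it to
# load_controlnet_mmdit(), which returns a ControlNet wrapper built
# with compression_ratio=1 around comfy.cldm.mmdit.ControlNet.
control = comfy.controlnet.load_controlnet(ckpt_path)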
comfy/ldm/modules/diffusionmodules/mmdit.py:

@@ -745,6 +745,8 @@ class MMDiT(nn.Module):
         qkv_bias: bool = True,
         context_processor_layers = None,
         context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -766,7 +768,10 @@ class MMDiT(nn.Module):
         # apply magic --> this defines a head_size of 64
         self.hidden_size = 64 * depth
         num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth

+        self.depth = depth
         self.num_heads = num_heads

         self.x_embedder = PatchEmbed(
@@ -821,7 +826,7 @@ class MMDiT(nn.Module):
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     attn_mode=attn_mode,
-                    pre_only=i == depth - 1,
+                    pre_only=(i == num_blocks - 1) and final_layer,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,

@@ -830,11 +835,12 @@ class MMDiT(nn.Module):
                     device=device,
                     operations=operations
                 )
-                for i in range(depth)
+                for i in range(num_blocks)
             ]
         )

-        self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

         if compile_core:
             assert False
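With final_layer=False, as the controlnet passes to super().__init__, no joint block is built pre_only, so every block still produces the x stream that controlnet_blocks[i] projects; the base model keeps its old behaviour of making only the last block pre_only. A quick illustration of how the new expression falls out (block counts are illustrative):

# How the pre_only flag is now computed for the base model vs. the controlnet.
def pre_only_flags(num_blocks, final_layer):
    return [(i == num_blocks - 1) and final_layer for i in range(num_blocks)]

print(pre_only_flags(24, final_layer=True))   # base MMDiT: only the last block is pre_only
print(pre_only_flags(12, final_layer=False))  # SD3 controlnet: every block keeps its full output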
@@ -893,6 +899,7 @@ class MMDiT(nn.Module):
         x: torch.Tensor,
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         if self.register_length > 0:
             context = torch.cat(

@@ -905,13 +912,20 @@ class MMDiT(nn.Module):
         # context is B, L', D
         # x is B, L, D

-        for block in self.joint_blocks:
-            context, x = block(
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
                 context,
                 x,
                 c=c_mod,
                 use_checkpoint=self.use_checkpoint,
             )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add

         x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
         return x
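On the consumer side, each entry of the control dict's "output" list is added to the token stream after the corresponding joint block. A standalone sketch of that per-block injection, with toy tensors standing in for the block outputs (shapes and block count are made up):

import torch

# Toy stand-ins: 4 joint blocks, token stream of shape (batch, tokens, hidden).
x = torch.zeros(1, 16, 8)
control = {"output": [torch.full((1, 16, 8), float(i)) for i in range(3)]}  # deliberately one entry short

blocks = 4
for i in range(blocks):
    # ... joint block i would update x and context here ...
    if control is not None:
        control_o = control.get("output")
        if i < len(control_o):           # lists shorter than the block count are tolerated
            add = control_o[i]
            if add is not None:
                x += add                 # residual injection, as in forward_core_with_concat

print(x[0, 0, 0])   # 0 + 1 + 2 = 3.0; block 3 received no residual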
@@ -922,6 +936,7 @@ class MMDiT(nn.Module):
         t: torch.Tensor,
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         """
         Forward pass of DiT.

@@ -943,7 +958,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)

-        x = self.forward_core_with_concat(x, c, context)
+        x = self.forward_core_with_concat(x, c, context, control)

         x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]

@@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         timesteps: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
+        control = None,
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y)
+        return super().forward(x, timesteps, context=context, y=y, control=control)
comfy/model_detection.py:

@@ -41,7 +41,9 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
         patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
         unet_config["patch_size"] = patch_size
-        unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
+        final_layer = '{}final_layer.linear.weight'.format(key_prefix)
+        if final_layer in state_dict:
+            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
         unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
         unet_config["input_size"] = None
@@ -435,10 +437,11 @@ def model_config_from_diffusers_unet(state_dict):
     return None

 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if depth > 0:
+    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
+    if num_blocks > 0:
+        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         out_sd = {}
-        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
+        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
         for k in sd_map:
             weight = state_dict.get(k, None)
             if weight is not None:
...
...@@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): ...@@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
key_map = {} key_map = {}
depth = mmdit_config.get("depth", 0) depth = mmdit_config.get("depth", 0)
for i in range(depth): num_blocks = mmdit_config.get("num_blocks", depth)
for i in range(num_blocks):
block_from = "transformer_blocks.{}".format(i) block_from = "transformer_blocks.{}".format(i)
block_to = "{}joint_blocks.{}".format(output_prefix, i) block_to = "{}joint_blocks.{}".format(output_prefix, i)
......