Commit 264caca2 authored by comfyanonymous's avatar comfyanonymous
Browse files

ControlNetApplySD3 node can now be used to use SD3 controlnets.

parent f8f7568d
import torch import torch
from typing import Dict, Optional from typing import Dict, Optional
import comfy.ldm.modules.diffusionmodules.mmdit import comfy.ldm.modules.diffusionmodules.mmdit
import comfy.latent_formats
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__( def __init__(
...@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): ...@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
operations=operations operations=operations
) )
self.latent_format = comfy.latent_formats.SD3()
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
...@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): ...@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
) -> torch.Tensor: ) -> torch.Tensor:
#weird sd3 controlnet specific stuff #weird sd3 controlnet specific stuff
hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)
y = torch.zeros_like(y) y = torch.zeros_like(y)
if self.context_processor is not None: if self.context_processor is not None:
context = self.context_processor(context) context = self.context_processor(context)
......
...@@ -7,6 +7,7 @@ import comfy.model_management ...@@ -7,6 +7,7 @@ import comfy.model_management
import comfy.model_detection import comfy.model_detection
import comfy.model_patcher import comfy.model_patcher
import comfy.ops import comfy.ops
import comfy.latent_formats
import comfy.cldm.cldm import comfy.cldm.cldm
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
...@@ -38,6 +39,8 @@ class ControlBase: ...@@ -38,6 +39,8 @@ class ControlBase:
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0) self.timestep_percent_range = (0.0, 1.0)
self.latent_format = None
self.vae = None
self.global_average_pooling = False self.global_average_pooling = False
self.timestep_range = None self.timestep_range = None
self.compression_ratio = 8 self.compression_ratio = 8
...@@ -48,10 +51,12 @@ class ControlBase: ...@@ -48,10 +51,12 @@ class ControlBase:
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
self.strength = strength self.strength = strength
self.timestep_percent_range = timestep_percent_range self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
self.vae = vae
return self return self
def pre_run(self, model, percent_to_timestep_function): def pre_run(self, model, percent_to_timestep_function):
...@@ -84,6 +89,8 @@ class ControlBase: ...@@ -84,6 +89,8 @@ class ControlBase:
c.global_average_pooling = self.global_average_pooling c.global_average_pooling = self.global_average_pooling
c.compression_ratio = self.compression_ratio c.compression_ratio = self.compression_ratio
c.upscale_algorithm = self.upscale_algorithm c.upscale_algorithm = self.upscale_algorithm
c.latent_format = self.latent_format
c.vae = self.vae
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
...@@ -129,7 +136,7 @@ class ControlBase: ...@@ -129,7 +136,7 @@ class ControlBase:
return out return out
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
...@@ -140,6 +147,7 @@ class ControlNet(ControlBase): ...@@ -140,6 +147,7 @@ class ControlNet(ControlBase):
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.model_sampling_current = None self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
...@@ -162,7 +170,17 @@ class ControlNet(ControlBase): ...@@ -162,7 +170,17 @@ class ControlNet(ControlBase):
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device) compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
...@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd): ...@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
if len(unexpected) > 0: if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype) latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
......
...@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3: ...@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
return ([[cond, {"pooled_output": pooled}]], ) return ([[cond, {"pooled_output": pooled}]], )
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "_for_testing/sd3"
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader, "TripleCLIPLoader": TripleCLIPLoader,
"EmptySD3LatentImage": EmptySD3LatentImage, "EmptySD3LatentImage": EmptySD3LatentImage,
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
} }
...@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced: ...@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning" CATEGORY = "conditioning"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent): def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
if strength == 0: if strength == 0:
return (positive, negative) return (positive, negative)
...@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced: ...@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets: if prev_cnet in cnets:
c_net = cnets[prev_cnet] c_net = cnets[prev_cnet]
else: else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent)) c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net.set_previous_controlnet(prev_cnet) c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net cnets[prev_cnet] = c_net
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment