"vscode:/vscode.git/clone" did not exist on "8e1f01c5190f4d1c0603ac1a3e9da257cb50bd40"
Commit d80af7ca authored by comfyanonymous's avatar comfyanonymous
Browse files

ControlNetApply now stacks.

It can be used to apply multiple control nets at the same time.
parent 8683ea42
...@@ -334,8 +334,13 @@ class ControlNet: ...@@ -334,8 +334,13 @@ class ControlNet:
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
self.device = device self.device = device
self.previous_controlnet = None
def get_control(self, x_noisy, t, cond_txt): def get_control(self, x_noisy, t, cond_txt):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
output_dtype = x_noisy.dtype output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
...@@ -354,10 +359,15 @@ class ControlNet: ...@@ -354,10 +359,15 @@ class ControlNet:
self.control_model = model_management.unload_if_low_vram(self.control_model) self.control_model = model_management.unload_if_low_vram(self.control_model)
out = [] out = []
autocast_enabled = torch.is_autocast_enabled() autocast_enabled = torch.is_autocast_enabled()
for x in control:
for i in range(len(control)):
x = control[i]
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype) x = x.to(output_dtype)
if control_prev is not None:
x += control_prev[i]
out.append(x) out.append(x)
return out return out
...@@ -366,7 +376,13 @@ class ControlNet: ...@@ -366,7 +376,13 @@ class ControlNet:
self.strength = strength self.strength = strength
return self return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
...@@ -377,6 +393,13 @@ class ControlNet: ...@@ -377,6 +393,13 @@ class ControlNet:
c.strength = self.strength c.strength = self.strength
return c return c
def get_control_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
out.append(self.control_model)
return out
def load_controlnet(ckpt_path): def load_controlnet(ckpt_path):
controlnet_data = load_torch_file(ckpt_path) controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
......
...@@ -252,7 +252,10 @@ class ControlNetApply: ...@@ -252,7 +252,10 @@ class ControlNetApply:
print(control_hint.shape) print(control_hint.shape)
for t in conditioning: for t in conditioning:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) c_net = control_net.copy().set_cond_hint(control_hint, strength)
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
c.append(n) c.append(n)
return (c, ) return (c, )
...@@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po ...@@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
control_nets += [p[1]['control']] control_nets += [p[1]['control']]
negative_copy += [[t] + n[1:]] negative_copy += [[t] + n[1:]]
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) control_net_models = []
for x in control_nets:
control_net_models += x.get_control_models()
model_management.load_controlnet_gpu(control_net_models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS: if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment