"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "37324a0a007036096755e58e2cd4af030cd0a2c5"
Commit d80af7ca authored by comfyanonymous's avatar comfyanonymous
Browse files

ControlNetApply now stacks.

It can be used to apply multiple control nets at the same time.
parent 8683ea42
......@@ -334,8 +334,13 @@ class ControlNet:
self.cond_hint = None
self.strength = 1.0
self.device = device
self.previous_controlnet = None
def get_control(self, x_noisy, t, cond_txt):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
......@@ -354,10 +359,15 @@ class ControlNet:
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = []
autocast_enabled = torch.is_autocast_enabled()
for x in control:
for i in range(len(control)):
x = control[i]
x *= self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)
if control_prev is not None:
x += control_prev[i]
out.append(x)
return out
......@@ -366,7 +376,13 @@ class ControlNet:
self.strength = strength
return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
......@@ -377,6 +393,13 @@ class ControlNet:
c.strength = self.strength
return c
def get_control_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
out.append(self.control_model)
return out
def load_controlnet(ckpt_path):
controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
......
......@@ -252,7 +252,10 @@ class ControlNetApply:
print(control_hint.shape)
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength)
c_net = control_net.copy().set_cond_hint(control_hint, strength)
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
c.append(n)
return (c, )
......@@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
control_nets += [p[1]['control']]
negative_copy += [[t] + n[1:]]
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
control_net_models = []
for x in control_nets:
control_net_models += x.get_control_models()
model_management.load_controlnet_gpu(control_net_models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment