"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "6ab8cad22ee376c0a9a506924cca278bb650afcc"
Commit 763b0cf0 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix control lora not working in fp32.

parent bc76b382
...@@ -926,8 +926,8 @@ class ControlLora(ControlNet): ...@@ -926,8 +926,8 @@ class ControlLora(ControlNet):
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps() controlnet_config["operations"] = ControlLoraOps()
self.control_model = cldm.ControlNet(**controlnet_config) self.control_model = cldm.ControlNet(**controlnet_config)
if model_management.should_use_fp16(): dtype = model.get_dtype()
self.control_model.half() self.control_model.to(dtype)
self.control_model.to(model_management.get_torch_device()) self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict() sd = diffusion_model.state_dict()
...@@ -947,7 +947,7 @@ class ControlLora(ControlNet): ...@@ -947,7 +947,7 @@ class ControlLora(ControlNet):
for k in self.control_weights: for k in self.control_weights:
if k not in {"lora_controlnet"}: if k not in {"lora_controlnet"}:
set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device())) set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device()))
def copy(self): def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment