Unverified Commit 380bfd82 authored by cmdr2's avatar cmdr2 Committed by GitHub
Browse files

Allow controlnets to be loaded (from ckpt) in a parallel thread with a SD...

Allow controlnets to be loaded (from ckpt) in a parallel thread with a SD model (ckpt), and speed it up slightly (#4298)

Faster controlnet model instantiation, and allow controlnets to be loaded (from ckpt) in a parallel thread with a SD model (ckpt) without  tensor errors (race condition)
parent 5989a85e
......@@ -1079,7 +1079,9 @@ def convert_controlnet_checkpoint(
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
controlnet = ControlNetModel(**ctrlnet_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
controlnet = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
......@@ -1097,7 +1099,11 @@ def convert_controlnet_checkpoint(
skip_extract_state_dict=skip_extract_state_dict,
)
controlnet.load_state_dict(converted_ctrl_checkpoint)
if is_accelerate_available():
for param_name, param in converted_ctrl_checkpoint.items():
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
else:
controlnet.load_state_dict(converted_ctrl_checkpoint)
return controlnet
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment