Unverified Commit 380bfd82 authored by cmdr2's avatar cmdr2 Committed by GitHub
Browse files

Allow controlnets to be loaded (from ckpt) in a parallel thread with a SD...

Allow controlnets to be loaded (from ckpt) in a parallel thread with a SD model (ckpt), and speed it up slightly (#4298)

Faster controlnet model instantiation, and allow controlnets to be loaded (from ckpt) in a parallel thread with a SD model (ckpt) without  tensor errors (race condition)
parent 5989a85e
...@@ -1079,7 +1079,9 @@ def convert_controlnet_checkpoint( ...@@ -1079,7 +1079,9 @@ def convert_controlnet_checkpoint(
if cross_attention_dim is not None: if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim ctrlnet_config["cross_attention_dim"] = cross_attention_dim
controlnet = ControlNetModel(**ctrlnet_config) ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
controlnet = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the # Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
...@@ -1097,7 +1099,11 @@ def convert_controlnet_checkpoint( ...@@ -1097,7 +1099,11 @@ def convert_controlnet_checkpoint(
skip_extract_state_dict=skip_extract_state_dict, skip_extract_state_dict=skip_extract_state_dict,
) )
controlnet.load_state_dict(converted_ctrl_checkpoint) if is_accelerate_available():
for param_name, param in converted_ctrl_checkpoint.items():
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
else:
controlnet.load_state_dict(converted_ctrl_checkpoint)
return controlnet return controlnet
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment