Commit 7d401ed1 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add ldm format support to UNETLoader.

parent 9562a6b4
...@@ -454,20 +454,26 @@ def load_unet(unet_path): #load unet in diffusers format ...@@ -454,20 +454,26 @@ def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters) fp16 = model_management.should_use_fp16(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) model_config = model_detection.model_config_from_unet(sd, "", fp16)
if model_config is None: if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return None new_sd = sd
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
new_sd = {} if model_config is None:
for k in diffusers_keys: print("ERROR UNSUPPORTED UNET", unet_path)
if k in sd: return None
new_sd[diffusers_keys[k]] = sd.pop(k)
else: diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
print(diffusers_keys[k], k)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment