Commit 603f02d6 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix loras not working when loading checkpoint with config.

parent ccb1b259
@@ -140,8 +140,10 @@ def unet_to_diffusers(unet_config):
     channel_mult = unet_config["channel_mult"]
     transformer_depth = unet_config["transformer_depth"]
     num_blocks = len(channel_mult)
-    if not isinstance(num_res_blocks, list):
+    if isinstance(num_res_blocks, int):
         num_res_blocks = [num_res_blocks] * num_blocks
+    if isinstance(transformer_depth, int):
+        transformer_depth = [transformer_depth] * num_blocks
     transformers_per_layer = []
     res = 1
@@ -152,7 +154,7 @@ def unet_to_diffusers(unet_config):
         transformers_per_layer.append(transformers)
         res *= 2
-    transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1])
+    transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
     diffusers_unet_map = {}
     for x in range(num_blocks):
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment