Commit 2c4e92a9 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix regression.

parent 5eddfdd8
...@@ -48,7 +48,12 @@ def detect_unet_config(state_dict, key_prefix, dtype): ...@@ -48,7 +48,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["dtype"] = dtype unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment