Commit a47f609f authored by comfyanonymous's avatar comfyanonymous
Browse files

Auto detect out_channels from model.

parent 79f73a4b
......@@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True,
"legacy": False
}
......@@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0]
num_res_blocks = []
channel_mult = []
......@@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment