Commit 65cae62c authored by comfyanonymous's avatar comfyanonymous
Browse files

No need to check filename extensions to detect shuffle controlnet.

parent 4e89b2c2
import torch
import math
import os
import comfy.utils
import comfy.model_management
import comfy.model_detection
......@@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
control_model = control_model.half()
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment