Commit 65cae62c authored by comfyanonymous's avatar comfyanonymous
Browse files

No need to check filename extensions to detect shuffle controlnet.

parent 4e89b2c2
import torch import torch
import math import math
import os
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
...@@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None): ...@@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
control_model = control_model.half() control_model = control_model.half()
global_average_pooling = False global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling) control = ControlNet(control_model, global_average_pooling=global_average_pooling)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment