Commit 35a4d011 authored by April Hu

Add support for the FLUX.1 tools (Canny, Depth, and Fill dev checkpoints) in ComfyUI with INT4 quantization

parent 50139c73
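For context, the new checkpoints are loaded through the same NunchakuFluxTransformer2dModel API the DiT loader already uses. A minimal, illustrative sketch of loading one of the newly listed INT4 repos outside ComfyUI (assuming the nunchaku package is installed and the checkpoint is reachable on the Hugging Face Hub):

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Repo name taken from the model list added in this commit; the other new
# canny/fill entries load the same way.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-depth-dev"
).to("cuda:0")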
@@ -7,16 +7,17 @@ import comfy.sd
import folder_paths
import GPUtil
import torch
+import numpy as np
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn
from transformers import T5EncoderModel
+from image_gen_aux import DepthPreprocessor

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel


class ComfyUIFluxForwardWrapper(nn.Module):
    def __init__(self, model: NunchakuFluxTransformer2dModel, config):
        super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -24,13 +25,25 @@ class ComfyUIFluxForwardWrapper(nn.Module):
        self.dtype = next(model.parameters()).dtype
        self.config = config

-    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
+    def forward(
+        self,
+        x,
+        timestep,
+        context,
+        y,
+        guidance,
+        control=None,
+        transformer_options={},
+        **kwargs,
+    ):
        assert control is None  # for now
        bs, c, h, w = x.shape
        patch_size = self.config["patch_size"]
        x = pad_to_patch_size(x, (patch_size, patch_size))
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        img = rearrange(
+            x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size
+        )
        h_len = (h + (patch_size // 2)) // patch_size
        w_len = (w + (patch_size // 2)) // patch_size
@@ -54,21 +67,30 @@ class ComfyUIFluxForwardWrapper(nn.Module):
            guidance=guidance if self.config["guidance_embed"] else None,
        ).sample
-        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
+        out = rearrange(
+            out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2
+        )[:, :, :h, :w]
        return out


class SVDQuantFluxDiTLoader:
    @classmethod
    def INPUT_TYPES(s):
-        model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
+        model_paths = [
+            "mit-han-lab/svdq-int4-flux.1-schnell",
+            "mit-han-lab/svdq-int4-flux.1-dev",
+            "mit-han-lab/svdq-int4-flux.1-canny-dev",
+            "mit-han-lab/svdq-int4-flux.1-depth-dev",
+            "mit-han-lab/svdq-int4-flux.1-fill-dev",
+        ]
        prefix = "models/diffusion_models"
        local_folders = os.listdir(prefix)
        local_folders = sorted(
            [
                folder
                for folder in local_folders
-                if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
+                if not folder.startswith(".")
+                and os.path.isdir(os.path.join(prefix, folder))
            ]
        )
        model_paths.extend(local_folders)
@@ -78,7 +100,14 @@ class SVDQuantFluxDiTLoader:
                "model_path": (model_paths,),
                "device_id": (
                    "INT",
-                    {"default": 0, "min": 0, "max": ngpus, "step": 1, "display": "number", "lazy": True},
+                    {
+                        "default": 0,
+                        "min": 0,
+                        "max": ngpus,
+                        "step": 1,
+                        "display": "number",
+                        "lazy": True,
+                    },
                ),
            }
        }
@@ -88,17 +117,20 @@ class SVDQuantFluxDiTLoader:
    CATEGORY = "SVDQuant"
    TITLE = "SVDQuant Flux DiT Loader"

-    def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
+    def load_model(
+        self, model_path: str, device_id: int, **kwargs
+    ) -> tuple[FluxTransformer2DModel]:
        device = f"cuda:{device_id}"
        prefix = "models/diffusion_models"
        if os.path.exists(os.path.join(prefix, model_path)):
            model_path = os.path.join(prefix, model_path)
        else:
            model_path = model_path
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(
+            device
+        )
        dit_config = {
            "image_model": "flux",
-            "in_channels": 16,
            "patch_size": 2,
            "out_channels": 16,
            "vec_in_dim": 768,
@@ -111,21 +143,34 @@ class SVDQuantFluxDiTLoader:
            "axes_dim": [16, 56, 56],
            "theta": 10000,
            "qkv_bias": True,
+            "guidance_embed": True,
            "disable_unet_model_creation": True,
        }
        if "schnell" in model_path:
            dit_config["guidance_embed"] = False
+            dit_config["in_channels"] = 16
            model_config = FluxSchnell(dit_config)
+        elif "canny" in model_path or "depth" in model_path:
+            dit_config["in_channels"] = 32
+            model_config = Flux(dit_config)
+        elif "fill" in model_path:
+            dit_config["in_channels"] = 64
+            model_config = Flux(dit_config)
        else:
-            assert "dev" in model_path
-            dit_config["guidance_embed"] = True
+            assert (
+                model_path == "mit-han-lab/svdq-int4-flux.1-dev"
+            ), f"model {model_path} not supported"
+            dit_config["in_channels"] = 16
            model_config = Flux(dit_config)
        model_config.set_inference_dtype(torch.bfloat16, None)
        model_config.custom_operations = None
        model = model_config.get_model({})
-        model.diffusion_model = ComfyUIFluxForwardWrapper(transformer, config=dit_config)
+        model.diffusion_model = ComfyUIFluxForwardWrapper(
+            transformer, config=dit_config
+        )
        model = comfy.model_patcher.ModelPatcher(model, device, device_id)
        return (model,)
@@ -157,7 +202,8 @@ class SVDQuantTextEncoderLoader:
            [
                folder
                for folder in local_folders
-                if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
+                if not folder.startswith(".")
+                and os.path.isdir(os.path.join(prefix, folder))
            ]
        )
        model_paths.extend(local_folders)
@@ -168,7 +214,14 @@ class SVDQuantTextEncoderLoader:
                "text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
                "t5_min_length": (
                    "INT",
-                    {"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
+                    {
+                        "default": 512,
+                        "min": 256,
+                        "max": 1024,
+                        "step": 128,
+                        "display": "number",
+                        "lazy": True,
+                    },
                ),
                "t5_precision": (["BF16", "INT4"],),
                "int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
@@ -191,8 +244,12 @@ class SVDQuantTextEncoderLoader:
        t5_precision: str,
        int4_model: str,
    ):
-        text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
-        text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
+        text_encoder_path1 = folder_paths.get_full_path_or_raise(
+            "text_encoders", text_encoder1
+        )
+        text_encoder_path2 = folder_paths.get_full_path_or_raise(
+            "text_encoders", text_encoder2
+        )
        if model_type == "flux":
            clip_type = comfy.sd.CLIPType.FLUX
        else:
@@ -223,7 +280,9 @@ class SVDQuantTextEncoderLoader:
            transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
            transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
            clip.cond_stage_model.t5xxl.transformer = (
-                transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
+                transformer.to(device=device, dtype=dtype)
+                if device.type == "cuda"
+                else transformer
            )
        return (clip,)
@@ -239,11 +298,17 @@ class SVDQuantLoraLoader:
        lora_name_list = [
            "None",
            *folder_paths.get_filename_list("loras"),
-            *[f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors" for n in hf_lora_names],
+            *[
+                f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors"
+                for n in hf_lora_names
+            ],
        ]
        return {
            "required": {
-                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
+                "model": (
+                    "MODEL",
+                    {"tooltip": "The diffusion model the LoRA will be applied to."},
+                ),
                "lora_name": (lora_name_list, {"tooltip": "The name of the LoRA."}),
                "lora_strength": (
                    "FLOAT",
@@ -292,8 +357,50 @@ class SVDQuantLoraLoader:
        return (model,)


+class DepthPreprocesser:
+    @classmethod
+    def INPUT_TYPES(s):
+        model_paths = ["LiheYoung/depth-anything-large-hf"]
+        prefix = "models/style_models"
+        local_folders = os.listdir(prefix)
+        local_folders = sorted(
+            [
+                folder
+                for folder in local_folders
+                if not folder.startswith(".")
+                and os.path.isdir(os.path.join(prefix, folder))
+            ]
+        )
+        model_paths.extend(local_folders)
+        return {
+            "required": {
+                "image": ("IMAGE", {}),
+                "model_path": (
+                    model_paths,
+                    {"tooltip": "Name of the depth preprocesser model."},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "depth_preprocess"
+    CATEGORY = "Flux.1"
+    TITLE = "Flux.1 Depth Preprocessor"
+
+    def depth_preprocess(self, image, model_path):
+        prefix = "models/style_models"
+        if os.path.exists(os.path.join(prefix, model_path)):
+            model_path = os.path.join(prefix, model_path)
+        processor = DepthPreprocessor.from_pretrained(model_path)
+        np_image = np.asarray(image)
+        np_result = np.array(processor(np_image)[0].convert("RGB"))
+        out_tensor = torch.from_numpy(np_result.astype(np.float32) / 255.0).unsqueeze(0)
+        return (out_tensor,)


NODE_CLASS_MAPPINGS = {
    "SVDQuantFluxDiTLoader": SVDQuantFluxDiTLoader,
    "SVDQuantTextEncoderLoader": SVDQuantTextEncoderLoader,
    "SVDQuantLoRALoader": SVDQuantLoraLoader,
+    "DepthPreprocesser": DepthPreprocesser,
}
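In short, the DiT loader now derives the transformer input width from the checkpoint variant. A condensed, illustrative restatement of the branching added to load_model above (the helper name is hypothetical; the values come from this diff):

def flux_in_channels(model_path: str) -> int:
    # Mirrors the dit_config["in_channels"] branches added in this commit.
    if "canny" in model_path or "depth" in model_path:
        return 32  # control-conditioned Canny/Depth dev checkpoints
    if "fill" in model_path:
        return 64  # Fill dev checkpoint carries extra inpainting conditioning channels
    return 16      # schnell and the plain dev checkpoint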