"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "05bd7897912c8bbef2145de78a44477d19ec76dc"
Commit 35a4d011 authored by April Hu
Browse files

Add support to enable flux.1 tools in ComfyUI with int4

parent 50139c73
......@@ -7,16 +7,17 @@ import comfy.sd
import folder_paths
import GPUtil
import torch
import numpy as np
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn
from transformers import T5EncoderModel
from image_gen_aux import DepthPreprocessor
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
class ComfyUIFluxForwardWrapper(nn.Module):
def __init__(self, model: NunchakuFluxTransformer2dModel, config):
super(ComfyUIFluxForwardWrapper, self).__init__()
......@@ -24,13 +25,25 @@ class ComfyUIFluxForwardWrapper(nn.Module):
self.dtype = next(model.parameters()).dtype
self.config = config
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
def forward(
self,
x,
timestep,
context,
y,
guidance,
control=None,
transformer_options={},
**kwargs,
):
assert control is None # for now
bs, c, h, w = x.shape
patch_size = self.config["patch_size"]
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
img = rearrange(
x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size
)
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
......@@ -54,21 +67,30 @@ class ComfyUIFluxForwardWrapper(nn.Module):
guidance=guidance if self.config["guidance_embed"] else None,
).sample
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
out = rearrange(
out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2
)[:, :, :h, :w]
return out
class SVDQuantFluxDiTLoader:
@classmethod
def INPUT_TYPES(s):
model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
model_paths = [
"mit-han-lab/svdq-int4-flux.1-schnell",
"mit-han-lab/svdq-int4-flux.1-dev",
"mit-han-lab/svdq-int4-flux.1-canny-dev",
"mit-han-lab/svdq-int4-flux.1-depth-dev",
"mit-han-lab/svdq-int4-flux.1-fill-dev",
]
prefix = "models/diffusion_models"
local_folders = os.listdir(prefix)
local_folders = sorted(
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
if not folder.startswith(".")
and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
......@@ -78,7 +100,14 @@ class SVDQuantFluxDiTLoader:
"model_path": (model_paths,),
"device_id": (
"INT",
{"default": 0, "min": 0, "max": ngpus, "step": 1, "display": "number", "lazy": True},
{
"default": 0,
"min": 0,
"max": ngpus,
"step": 1,
"display": "number",
"lazy": True,
},
),
}
}
......@@ -88,17 +117,20 @@ class SVDQuantFluxDiTLoader:
CATEGORY = "SVDQuant"
TITLE = "SVDQuant Flux DiT Loader"
def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
def load_model(
self, model_path: str, device_id: int, **kwargs
) -> tuple[FluxTransformer2DModel]:
device = f"cuda:{device_id}"
prefix = "models/diffusion_models"
if os.path.exists(os.path.join(prefix, model_path)):
model_path = os.path.join(prefix, model_path)
else:
model_path = model_path
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(
device
)
dit_config = {
"image_model": "flux",
"in_channels": 16,
"patch_size": 2,
"out_channels": 16,
"vec_in_dim": 768,
......@@ -111,21 +143,34 @@ class SVDQuantFluxDiTLoader:
"axes_dim": [16, 56, 56],
"theta": 10000,
"qkv_bias": True,
"guidance_embed": True,
"disable_unet_model_creation": True,
}
if "schnell" in model_path:
dit_config["guidance_embed"] = False
dit_config["in_channels"] = 16
model_config = FluxSchnell(dit_config)
elif "canny" in model_path or "depth" in model_path:
dit_config["in_channels"] = 32
model_config = Flux(dit_config)
elif "fill" in model_path:
dit_config["in_channels"] = 64
model_config = Flux(dit_config)
else:
assert "dev" in model_path
dit_config["guidance_embed"] = True
assert (
model_path == "mit-han-lab/svdq-int4-flux.1-dev"
), f"model {model_path} not supported"
dit_config["in_channels"] = 16
model_config = Flux(dit_config)
model_config.set_inference_dtype(torch.bfloat16, None)
model_config.custom_operations = None
model = model_config.get_model({})
model.diffusion_model = ComfyUIFluxForwardWrapper(transformer, config=dit_config)
model.diffusion_model = ComfyUIFluxForwardWrapper(
transformer, config=dit_config
)
model = comfy.model_patcher.ModelPatcher(model, device, device_id)
return (model,)
......@@ -157,7 +202,8 @@ class SVDQuantTextEncoderLoader:
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
if not folder.startswith(".")
and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
......@@ -168,7 +214,14 @@ class SVDQuantTextEncoderLoader:
"text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
"t5_min_length": (
"INT",
{"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
{
"default": 512,
"min": 256,
"max": 1024,
"step": 128,
"display": "number",
"lazy": True,
},
),
"t5_precision": (["BF16", "INT4"],),
"int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
......@@ -191,8 +244,12 @@ class SVDQuantTextEncoderLoader:
t5_precision: str,
int4_model: str,
):
text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
text_encoder_path1 = folder_paths.get_full_path_or_raise(
"text_encoders", text_encoder1
)
text_encoder_path2 = folder_paths.get_full_path_or_raise(
"text_encoders", text_encoder2
)
if model_type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
else:
......@@ -223,7 +280,9 @@ class SVDQuantTextEncoderLoader:
transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
clip.cond_stage_model.t5xxl.transformer = (
transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
transformer.to(device=device, dtype=dtype)
if device.type == "cuda"
else transformer
)
return (clip,)
......@@ -239,11 +298,17 @@ class SVDQuantLoraLoader:
lora_name_list = [
"None",
*folder_paths.get_filename_list("loras"),
*[f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors" for n in hf_lora_names],
*[
f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors"
for n in hf_lora_names
],
]
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"model": (
"MODEL",
{"tooltip": "The diffusion model the LoRA will be applied to."},
),
"lora_name": (lora_name_list, {"tooltip": "The name of the LoRA."}),
"lora_strength": (
"FLOAT",
......@@ -292,8 +357,50 @@ class SVDQuantLoraLoader:
return (model,)
class DepthPreprocesser:
    """Node that runs a depth-estimation preprocessor on an input image.

    The model is chosen from a built-in hub id plus any sub-folders found
    under ``models/style_models`` (resolved relative to the current working
    directory).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: an image and a depth model selection.

        Returns:
            A dict with a ``"required"`` section holding the ``image`` input
            and the ``model_path`` choice list. The list starts with the
            default hub id, followed by the sorted names of non-hidden
            sub-directories of ``models/style_models``.
        """
        model_paths = ["LiheYoung/depth-anything-large-hf"]
        prefix = "models/style_models"
        try:
            entries = os.listdir(prefix)
        except (FileNotFoundError, NotADirectoryError):
            # The folder may not exist (e.g. a fresh install); offer only the
            # default hub id instead of failing while the inputs are built.
            entries = []
        local_folders = sorted(
            folder
            for folder in entries
            if not folder.startswith(".")
            and os.path.isdir(os.path.join(prefix, folder))
        )
        model_paths.extend(local_folders)
        return {
            "required": {
                "image": ("IMAGE", {}),
                "model_path": (
                    model_paths,
                    {"tooltip": "Name of the depth preprocesser model."},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "depth_preprocess"
    CATEGORY = "Flux.1"
    TITLE = "Flux.1 Depth Preprocessor"

    def depth_preprocess(self, image, model_path):
        """Compute a depth image for ``image`` with the selected model.

        Args:
            image: The input image; it is converted with ``np.asarray``
                before being handed to the preprocessor.
                NOTE(review): assumes the incoming value is directly
                convertible by NumPy (e.g. a CPU tensor) -- confirm.
            model_path: Either the name of a folder under
                ``models/style_models`` or an identifier passed through
                unchanged to ``DepthPreprocessor.from_pretrained``.

        Returns:
            A 1-tuple containing a float32 tensor: the RGB depth result
            divided by 255 with a leading dimension of size 1 added.
        """
        prefix = "models/style_models"
        # Prefer a local folder of that name; otherwise keep the value as-is.
        if os.path.exists(os.path.join(prefix, model_path)):
            model_path = os.path.join(prefix, model_path)
        # The preprocessor is re-created on every call; nothing is cached.
        processor = DepthPreprocessor.from_pretrained(model_path)
        np_image = np.asarray(image)
        # Only the first result is used; it is presumably a PIL image, since
        # it exposes ``.convert`` -- confirm.
        np_result = np.array(processor(np_image)[0].convert("RGB"))
        out_tensor = torch.from_numpy(np_result.astype(np.float32) / 255.0).unsqueeze(0)
        return (out_tensor,)
# Maps each node identifier string to the class that implements it.
# Presumably read by the host application's custom-node loader -- confirm.
# Note the "SVDQuantLoRALoader" key is cased differently from its class name
# (SVDQuantLoraLoader); the key is a runtime identifier and is kept as-is.
NODE_CLASS_MAPPINGS = {
"SVDQuantFluxDiTLoader": SVDQuantFluxDiTLoader,
"SVDQuantTextEncoderLoader": SVDQuantTextEncoderLoader,
"SVDQuantLoRALoader": SVDQuantLoraLoader,
"DepthPreprocesser": DepthPreprocesser,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment