Commit 742a8006 authored by Muyang Li, committed by Zhekai Zhang

[major] Fix the tempfile bug in the ComfyUI nodes

parent 27232e7b
+import logging
 import os
-import tempfile
 
 import folder_paths
 from safetensors.torch import save_file
 
 from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("SVDQuantFluxLoraLoader")
+
 
 class SVDQuantFluxLoraLoader:
     def __init__(self):
@@ -13,31 +16,20 @@ class SVDQuantFluxLoraLoader:
     @classmethod
     def INPUT_TYPES(s):
-        lora_name_list = [
-            "None",
-            *folder_paths.get_filename_list("loras"),
-            "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
-        ]
+        lora_name_list = ["None", *folder_paths.get_filename_list("loras")]
 
-        base_model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
-        prefix = os.path.join(folder_paths.models_dir, "diffusion_models")
-        local_base_model_folders = os.listdir(prefix)
-        local_base_model_folders = sorted(
-            [
-                folder
-                for folder in local_base_model_folders
-                if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
-            ]
-        )
-        base_model_paths = local_base_model_folders + base_model_paths
+        prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
+        base_model_paths = set()
+        for prefix in prefixes:
+            if os.path.exists(prefix) and os.path.isdir(prefix):
+                base_model_paths_ = os.listdir(prefix)
+                base_model_paths_ = [
+                    folder
+                    for folder in base_model_paths_
+                    if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
+                ]
+                base_model_paths.update(base_model_paths_)
+        base_model_paths = sorted(list(base_model_paths))
 
         return {
             "required": {
@@ -63,6 +55,12 @@ class SVDQuantFluxLoraLoader:
                         "tooltip": "How strongly to modify the diffusion model. This value can be negative.",
                     },
                 ),
+                "save_converted_lora": (
+                    ["disable", "enable"],
+                    {
+                        "tooltip": "If enabled, the converted LoRA will be saved as a .safetensors file in the same directory as your LoRA file."
+                    },
+                ),
             }
         }
@@ -78,7 +76,15 @@ class SVDQuantFluxLoraLoader:
         "Currently, only one LoRA node can be applied."
     )
 
-    def load_lora(self, model, lora_name: str, lora_format: str, base_model_name: str, lora_strength: float):
+    def load_lora(
+        self,
+        model,
+        lora_name: str,
+        lora_format: str,
+        base_model_name: str,
+        lora_strength: float,
+        save_converted_lora: str,
+    ):
         if self.cur_lora_name == lora_name:
             if self.cur_lora_name == "None":
                 pass  # Do nothing since the lora is None
@@ -110,9 +116,22 @@ class SVDQuantFluxLoraLoader:
             base_model_path = os.path.join(base_model_name, "transformer_blocks.safetensors")
             state_dict = convert_to_nunchaku_flux_lowrank_dict(base_model_path, input_lora)
-            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp_file:
-                save_file(state_dict, tmp_file.name)
-                model.model.diffusion_model.model.update_lora_params(tmp_file.name)
+            if save_converted_lora == "enable" and lora_format != "svdquant":
+                dirname = os.path.dirname(lora_path)
+                basename = os.path.basename(lora_path)
+                if "int4" in base_model_path:
+                    precision = "int4"
+                else:
+                    assert "fp4" in base_model_path
+                    precision = "fp4"
+                converted_name = f"svdq-{precision}-{basename}"
+                lora_converted_path = os.path.join(dirname, converted_name)
+                if not os.path.exists(lora_converted_path):
+                    save_file(state_dict, lora_converted_path)
+                    logger.info(f"Saved converted LoRA to: {lora_converted_path}")
+                else:
+                    logger.info(f"Converted LoRA already exists at: {lora_converted_path}")
+            model.model.diffusion_model.model.update_lora_params(state_dict)
         else:
             model.model.diffusion_model.model.update_lora_params(lora_path)
         model.model.diffusion_model.model.set_lora_strength(lora_strength)
...
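
Note: the node above caches the converted LoRA next to the source file under the name svdq-{precision}-{basename}. The following is a minimal sketch of the same conversion outside ComfyUI; the input path and base-model location are hypothetical, and the exact comfyui2diffusers signature is an assumption based on the imports above.

import os

from safetensors.torch import save_file

from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict

lora_path = "loras/my_style.safetensors"  # hypothetical ComfyUI-format LoRA
base_model_path = "models/svdq-int4-flux.1-dev/transformer_blocks.safetensors"  # hypothetical base weights

# Mirror the node's naming scheme: svdq-{precision}-{original file name}
precision = "int4" if "int4" in base_model_path else "fp4"
converted_path = os.path.join(os.path.dirname(lora_path), f"svdq-{precision}-{os.path.basename(lora_path)}")

if not os.path.exists(converted_path):  # convert once, reuse the cached file afterwards
    input_lora = comfyui2diffusers(lora_path)  # assumed call: maps ComfyUI LoRA keys to the diffusers layout
    state_dict = convert_to_nunchaku_flux_lowrank_dict(base_model_path, input_lora)
    save_file(state_dict, converted_path)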
 import os
 
 import comfy.model_patcher
 import folder_paths
 import torch
@@ -7,8 +8,10 @@ from comfy.supported_models import Flux, FluxSchnell
 from diffusers import FluxTransformer2DModel
 from einops import rearrange, repeat
 from torch import nn
 
 from nunchaku import NunchakuFluxTransformer2dModel
 
 
 class ComfyUIFluxForwardWrapper(nn.Module):
     def __init__(self, model: NunchakuFluxTransformer2dModel, config):
         super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -59,18 +62,10 @@ class ComfyUIFluxForwardWrapper(nn.Module):
         out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
         return out
 
 
 class SVDQuantFluxDiTLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
         prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
         local_folders = set()
         for prefix in prefixes:
@@ -82,8 +77,7 @@ class SVDQuantFluxDiTLoader:
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
             local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths = local_folders + model_paths
+        model_paths = sorted(list(local_folders))
         ngpus = torch.cuda.device_count()
         return {
             "required": {
@@ -126,35 +120,37 @@ class SVDQuantFluxDiTLoader:
                 model_path = os.path.join(prefix, model_path)
                 break
 
-        # 验证 device_id 是否有效
+        # Check if the device_id is valid
         if device_id >= torch.cuda.device_count():
             raise ValueError(f"Invalid device_id: {device_id}. Only {torch.cuda.device_count()} GPUs available.")
 
-        # 获取 ComfyUI 指定 CUDA 设备的显存信息
+        # Get the GPU properties
         gpu_properties = torch.cuda.get_device_properties(device_id)
-        gpu_memory = gpu_properties.total_memory / (1024 ** 2)  # 转换为 MB
+        gpu_memory = gpu_properties.total_memory / (1024**2)  # Convert to MB
         gpu_name = gpu_properties.name
-        print(f"GPU {device_id} ({gpu_name}) 显存: {gpu_memory} MB")
+        print(f"GPU {device_id} ({gpu_name}) Memory: {gpu_memory} MB")
 
-        # 确定 CPU offload 是否启用
+        # Check if CPU offload needs to be enabled
         if cpu_offload == "auto":
-            if gpu_memory < 14336:  # 14GB 阈值
+            if gpu_memory < 14336:  # 14GB threshold
                 cpu_offload_enabled = True
-                print("因显存小于14GB,启用 CPU offload")
+                print("VRAM < 14GiB, enable CPU offload")
             else:
                 cpu_offload_enabled = False
-                print("显存大于14GB,不启用 CPU offload")
+                print("VRAM >= 14GiB, disable CPU offload")
         elif cpu_offload == "enable":
             cpu_offload_enabled = True
-            print("用户启用 CPU offload")
+            print("Enable CPU offload")
         else:
             cpu_offload_enabled = False
-            print("用户禁用 CPU offload")
+            print("Disable CPU offload")
 
-        # 清理 GPU 缓存
-        # torch.cuda.empty_cache()
+        capability = torch.cuda.get_device_capability(0)
+        sm = f"{capability[0]}{capability[1]}"
+        precision = "fp4" if sm == "120" else "int4"
 
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload_enabled)
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            model_path, precision=precision, offload=cpu_offload_enabled
+        )
         transformer = transformer.to(device)
 
         dit_config = {
             "image_model": "flux",
...
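
Note: the DiT loader now derives the weight precision from the CUDA compute capability instead of encoding it in the model name: SM 120 (Blackwell) selects fp4, everything else falls back to int4. The same check in isolation:

import torch

# fp4 kernels target SM 120 (Blackwell); earlier architectures use the int4 build.
major, minor = torch.cuda.get_device_capability(0)
precision = "fp4" if f"{major}{minor}" == "120" else "int4"
print(f"Detected SM {major}{minor}; loading {precision} weights")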
@@ -45,7 +45,6 @@ class WrappedEmbedding(nn.Module):
 
 class SVDQuantTextEncoderLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["mit-han-lab/svdq-flux.1-t5"]
         prefixes = folder_paths.folder_names_and_paths["text_encoders"][0]
         local_folders = set()
         for prefix in prefixes:
@@ -57,8 +56,7 @@ class SVDQuantTextEncoderLoader:
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
             local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths.extend(local_folders)
+        model_paths = sorted(list(local_folders))
         return {
             "required": {
                 "model_type": (["flux"],),
@@ -68,8 +66,8 @@
                     "INT",
                     {"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
                 ),
-                "t5_precision": (["BF16", "INT4"],),
-                "int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
+                "use_4bit_t5": (["disable", "enable"],),
+                "int4_model": (model_paths, {"tooltip": "The name of the 4-bit T5 model."}),
             }
         }
@@ -86,7 +84,7 @@
         text_encoder1: str,
         text_encoder2: str,
         t5_min_length: int,
-        t5_precision: str,
+        use_4bit_t5: str,
         int4_model: str,
     ):
         text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
@@ -105,7 +103,7 @@
         if model_type == "flux":
             clip.tokenizer.t5xxl.min_length = t5_min_length
-        if t5_precision == "INT4":
+        if use_4bit_t5 == "enable":
             transformer = clip.cond_stage_model.t5xxl.transformer
             param = next(transformer.parameters())
             dtype = param.dtype
...
@@ -3,13 +3,12 @@ import os
 
 import folder_paths
 import numpy as np
 import torch
-from image_gen_aux import DepthPreprocessor
 
 
 class FluxDepthPreprocessor:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["LiheYoung/depth-anything-large-hf"]
+        model_paths = []
         prefix = os.path.join(folder_paths.models_dir, "checkpoints")
         local_folders = os.listdir(prefix)
         local_folders = sorted(
@@ -36,9 +35,13 @@ class FluxDepthPreprocessor:
     TITLE = "FLUX.1 Depth Preprocessor"
 
     def depth_preprocess(self, image, model_path):
-        prefix = os.path.join(folder_paths.models_dir, "checkpoints")
-        if os.path.exists(os.path.join(prefix, model_path)):
-            model_path = os.path.join(prefix, model_path)
+        prefixes = folder_paths.folder_names_and_paths["checkpoints"][0]
+        for prefix in prefixes:
+            if os.path.exists(os.path.join(prefix, model_path)):
+                model_path = os.path.join(prefix, model_path)
+                break
+
+        from image_gen_aux import DepthPreprocessor
 
         processor = DepthPreprocessor.from_pretrained(model_path)
         np_image = np.asarray(image)
         np_result = np.array(processor(np_image)[0].convert("RGB"))
...
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", text_encoder_2=text_encoder_2, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
@@ -60,12 +60,12 @@ def quantize_t5_encoder(
         if isinstance(module, nn.Linear):
             if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix):
                 print(f"Switching {name} to W4Linear")
-                qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
-                qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
-                if qmodule.bias is not None:
-                    qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
-                qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
-                qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
+                qmodule = W4Linear.from_linear(module, group_size=128, init_only=False)
+                # qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
+                # if qmodule.bias is not None:
+                #     qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
+                # qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
+                # qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
 
                 # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                 qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
...
@@ -8,7 +8,6 @@ from huggingface_hub import utils
 from packaging.version import Version
 from torch import nn
 
-from nunchaku.utils import fetch_or_download
 from .utils import NunchakuModelLoaderMixin, pad_tensor
 from ..._C import QuantizedFluxModel, utils as cutils
 from ...utils import load_state_dict_in_safetensors
@@ -224,13 +223,18 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
             new_state_dict[k] = v
         self.load_state_dict(new_state_dict, strict=True)
 
-    def update_lora_params(self, path: str):
-        state_dict = load_state_dict_in_safetensors(path)
+    def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
+        if isinstance(path_or_state_dict, dict):
+            state_dict = path_or_state_dict
+        else:
+            state_dict = load_state_dict_in_safetensors(path_or_state_dict)
 
         unquantized_loras = {}
         for k in state_dict.keys():
             if "transformer_blocks" not in k:
                 unquantized_loras[k] = state_dict[k]
+        for k in unquantized_loras.keys():
+            state_dict.pop(k)
 
         self.unquantized_loras = unquantized_loras
         if len(unquantized_loras) > 0:
@@ -239,10 +243,9 @@
             self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
             self.update_unquantized_lora_params(1)
 
-        path = fetch_or_download(path)
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
-        block.m.load(path, True)
+        block.m.loadDict(state_dict, True)
 
     def set_lora_strength(self, strength: float = 1):
         block = self.transformer_blocks[0]
...
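
Note: with this change, update_lora_params accepts either a safetensors path or an in-memory state dict, which is what lets the ComfyUI LoRA node pass its converted weights directly instead of round-tripping through a temporary file. A short usage sketch (model and LoRA names are placeholders):

from safetensors.torch import load_file

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")

# Both call forms are supported after this commit:
transformer.update_lora_params("my_lora.safetensors")             # path on disk
transformer.update_lora_params(load_file("my_lora.safetensors"))  # dict already in memory
transformer.set_lora_strength(0.8)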