Unverified Commit 3fb54988 authored by Muyang Li, committed by GitHub

Merge pull request #174 from botnang/fix/auto-cpu-offload

Auto CPU Offload in flux.py: add an "auto" option for CPU offloading
parents 3ed7a853 04611d15
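
In effect, the new "auto" mode reads the total memory of the selected CUDA device and turns CPU offloading on only when it is below roughly 14 GB (14336 MB). The sketch below is a minimal, standalone illustration of that decision, not the node's actual code: `resolve_cpu_offload` and the example model path are hypothetical stand-ins, while the `from_pretrained(..., offload=...)` call mirrors the one in the diff.

```python
import torch

from nunchaku import NunchakuFluxTransformer2dModel


def resolve_cpu_offload(choice: str, device_id: int = 0) -> bool:
    """Map the node's "auto"/"enable"/"disable" choice to a boolean offload flag."""
    if choice == "auto":
        # Total VRAM of the selected GPU, in MB.
        total_mb = torch.cuda.get_device_properties(device_id).total_memory / (1024 ** 2)
        return total_mb < 14336  # offload only when the GPU has less than ~14 GB
    return choice == "enable"


if __name__ == "__main__":
    model_path = "mit-han-lab/svdq-int4-flux.1-dev"  # hypothetical example checkpoint
    offload = resolve_cpu_offload("auto", device_id=0)
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=offload)
    transformer = transformer.to("cuda:0")
```
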
import os
import comfy.model_patcher
import folder_paths
import GPUtil
import torch
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn
from nunchaku import NunchakuFluxTransformer2dModel
class ComfyUIFluxForwardWrapper(nn.Module):
def __init__(self, model: NunchakuFluxTransformer2dModel, config):
super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -63,7 +59,6 @@ class ComfyUIFluxForwardWrapper(nn.Module):
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
return out
class SVDQuantFluxDiTLoader:
@classmethod
def INPUT_TYPES(s):
@@ -89,7 +84,7 @@ class SVDQuantFluxDiTLoader:
local_folders.update(local_folders_)
local_folders = sorted(list(local_folders))
model_paths = local_folders + model_paths
-ngpus = len(GPUtil.getGPUs())
+ngpus = torch.cuda.device_count()
return {
"required": {
"model_path": (
@@ -97,10 +92,10 @@ class SVDQuantFluxDiTLoader:
{"tooltip": "The SVDQuant quantized FLUX.1 models. It can be a huggingface path or a local path."},
),
"cpu_offload": (
["enable", "disable"],
["auto", "enable", "disable"],
{
"default": "disable",
"tooltip": "Whether to enable CPU offload for the transformer model. This may slow down the inference, but may reduce the GPU memory usage.",
"default": "auto",
"tooltip": "Whether to enable CPU offload for the transformer model. 'auto' will enable it if the GPU memory is less than 14G.",
},
),
"device_id": (
@@ -108,7 +103,7 @@ class SVDQuantFluxDiTLoader:
{
"default": 0,
"min": 0,
"max": ngpus,
"max": ngpus - 1,
"step": 1,
"display": "number",
"lazy": True,
@@ -130,7 +125,36 @@ class SVDQuantFluxDiTLoader:
if os.path.exists(os.path.join(prefix, model_path)):
model_path = os.path.join(prefix, model_path)
break
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload == "enable")
+# Validate that device_id refers to an available GPU
+if device_id >= torch.cuda.device_count():
+    raise ValueError(f"Invalid device_id: {device_id}. Only {torch.cuda.device_count()} GPUs available.")
+# Get the VRAM of the CUDA device selected in ComfyUI
+gpu_properties = torch.cuda.get_device_properties(device_id)
+gpu_memory = gpu_properties.total_memory / (1024 ** 2)  # convert to MB
+gpu_name = gpu_properties.name
+print(f"GPU {device_id} ({gpu_name}) VRAM: {gpu_memory} MB")
+# Decide whether CPU offload should be enabled
+if cpu_offload == "auto":
+    if gpu_memory < 14336:  # 14 GB threshold
+        cpu_offload_enabled = True
+        print("VRAM is less than 14GB; enabling CPU offload")
+    else:
+        cpu_offload_enabled = False
+        print("VRAM is at least 14GB; CPU offload not enabled")
+elif cpu_offload == "enable":
+    cpu_offload_enabled = True
+    print("CPU offload enabled by user")
+else:
+    cpu_offload_enabled = False
+    print("CPU offload disabled by user")
+# Clear the GPU cache
+# torch.cuda.empty_cache()
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload_enabled)
transformer = transformer.to(device)
dit_config = {
"image_model": "flux",
......