Unverified Commit 3fb54988 authored by Muyang Li, committed by GitHub

Merge pull request #174 from botnang/fix/auto-cpu-offload

Auto CPU Offload in flux.py: add an "auto" option for CPU offloading
parents 3ed7a853 04611d15
import os
import comfy.model_patcher
import folder_paths
-import GPUtil
import torch
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn

from nunchaku import NunchakuFluxTransformer2dModel


class ComfyUIFluxForwardWrapper(nn.Module):
    def __init__(self, model: NunchakuFluxTransformer2dModel, config):
        super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -63,7 +59,6 @@ class ComfyUIFluxForwardWrapper(nn.Module):
        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
        return out


class SVDQuantFluxDiTLoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -89,7 +84,7 @@ class SVDQuantFluxDiTLoader:
                local_folders.update(local_folders_)
        local_folders = sorted(list(local_folders))
        model_paths = local_folders + model_paths
-        ngpus = len(GPUtil.getGPUs())
+        ngpus = torch.cuda.device_count()
        return {
            "required": {
                "model_path": (
@@ -97,10 +92,10 @@ class SVDQuantFluxDiTLoader:
                    {"tooltip": "The SVDQuant quantized FLUX.1 models. It can be a huggingface path or a local path."},
                ),
                "cpu_offload": (
-                    ["enable", "disable"],
+                    ["auto", "enable", "disable"],
                    {
-                        "default": "disable",
-                        "tooltip": "Whether to enable CPU offload for the transformer model. This may slow down the inference, but may reduce the GPU memory usage.",
+                        "default": "auto",
+                        "tooltip": "Whether to enable CPU offload for the transformer model. 'auto' will enable it if the GPU memory is less than 14G.",
                    },
                ),
                "device_id": (
@@ -108,7 +103,7 @@ class SVDQuantFluxDiTLoader:
                    {
                        "default": 0,
                        "min": 0,
-                        "max": ngpus,
+                        "max": ngpus - 1,
                        "step": 1,
                        "display": "number",
                        "lazy": True,
@@ -130,7 +125,36 @@ class SVDQuantFluxDiTLoader:
            if os.path.exists(os.path.join(prefix, model_path)):
                model_path = os.path.join(prefix, model_path)
                break
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload == "enable")
+        # Validate that device_id refers to an available GPU
+        if device_id >= torch.cuda.device_count():
+            raise ValueError(f"Invalid device_id: {device_id}. Only {torch.cuda.device_count()} GPUs available.")
+
+        # Query the memory of the CUDA device selected for ComfyUI
+        gpu_properties = torch.cuda.get_device_properties(device_id)
+        gpu_memory = gpu_properties.total_memory / (1024 ** 2)  # convert to MB
+        gpu_name = gpu_properties.name
+        print(f"GPU {device_id} ({gpu_name}) memory: {gpu_memory} MB")
+
+        # Decide whether CPU offload should be enabled
+        if cpu_offload == "auto":
+            if gpu_memory < 14336:  # 14 GB threshold
+                cpu_offload_enabled = True
+                print("Enabling CPU offload because VRAM is less than 14GB")
+            else:
+                cpu_offload_enabled = False
+                print("VRAM is at least 14GB, not enabling CPU offload")
+        elif cpu_offload == "enable":
+            cpu_offload_enabled = True
+            print("CPU offload enabled by user")
+        else:
+            cpu_offload_enabled = False
+            print("CPU offload disabled by user")
+
+        # Clear the GPU cache
+        # torch.cuda.empty_cache()
+
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload_enabled)
        transformer = transformer.to(device)
        dit_config = {
            "image_model": "flux",
......
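For reference, a minimal standalone sketch of the offload decision this change introduces (the 14336 MB threshold and the torch.cuda calls mirror the diff above; the function name and signature are illustrative only, not part of the node):

import torch

def resolve_cpu_offload(cpu_offload: str, device_id: int, threshold_mb: int = 14336) -> bool:
    # "enable"/"disable" map directly to the offload flag.
    if cpu_offload == "enable":
        return True
    if cpu_offload == "disable":
        return False
    # "auto": offload only when the selected GPU has less than ~14 GB of total memory.
    total_mb = torch.cuda.get_device_properties(device_id).total_memory / (1024 ** 2)
    return total_mb < threshold_mb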