fengzch-das / nunchaku · Commits

Commit 742a8006
authored Mar 11, 2025 by Muyang Li, committed by Zhekai Zhang on Apr 01, 2025

[major] Fix the tempfile bug in the comfyui

parent 27232e7b
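The bug in one sentence: when a LoRA had to be converted on the fly, the ComfyUI node wrote the converted state dict into a tempfile.NamedTemporaryFile and immediately re-read it by name. Python's tempfile documentation notes that a named temporary file cannot be reopened by name on Windows while it is still open, so loading failed there, and the disk round-trip was wasted I/O everywhere else. A minimal sketch of the before/after pattern, with simplified names (`model.update_lora_params` stands in for the full `model.model.diffusion_model.model.update_lora_params` chain in the diff below):

import tempfile

from safetensors.torch import save_file


def apply_lora_old(model, state_dict):
    # Before: round-trip through a named temp file. On Windows the file
    # cannot be reopened by name while the `with` block still holds it open.
    with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp:
        save_file(state_dict, tmp.name)      # write to disk...
        model.update_lora_params(tmp.name)   # ...and immediately re-read it


def apply_lora_new(model, state_dict):
    # After: update_lora_params() accepts the in-memory dict directly,
    # so no temporary file is created at all.
    model.update_lora_params(state_dict)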
Showing 8 changed files with 97 additions and 113 deletions
comfyui/nodes/lora/flux.py                         +48  -29
comfyui/nodes/models/flux.py                       +22  -26
comfyui/nodes/models/text_encoder.py                +5   -7
comfyui/nodes/preprocessors/depth.py                +8   -5
examples/int4-flux.1-dev-qencoder-offload.py        +0  -15
examples/int4-flux.1-schnell-qencoder-offload.py    +0  -20
nunchaku/models/text_encoders/t5_encoder.py         +6   -6
nunchaku/models/transformers/transformer_flux.py    +8   -5
=== comfyui/nodes/lora/flux.py ===

 import logging
 import os
 import tempfile

 import folder_paths
 from safetensors.torch import save_file

 from nunchaku.lora.flux import (
     comfyui2diffusers,
     convert_to_nunchaku_flux_lowrank_dict,
     detect_format,
     xlab2diffusers,
 )

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("SVDQuantFluxLoraLoader")


 class SVDQuantFluxLoraLoader:
     def __init__(self):
         ...

@@ -13,31 +16,20 @@ class SVDQuantFluxLoraLoader:
     @classmethod
     def INPUT_TYPES(s):
-        lora_name_list = [
-            "None",
-            *folder_paths.get_filename_list("loras"),
-            "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
-        ]
+        lora_name_list = ["None", *folder_paths.get_filename_list("loras")]
-        base_model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
-        prefix = os.path.join(folder_paths.models_dir, "diffusion_models")
-        local_base_model_folders = os.listdir(prefix)
-        local_base_model_folders = sorted(
-            [
-                folder
-                for folder in local_base_model_folders
-                if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
-            ]
-        )
-        base_model_paths = local_base_model_folders + base_model_paths
+        prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
+        base_model_paths = set()
+        for prefix in prefixes:
+            if os.path.exists(prefix) and os.path.isdir(prefix):
+                base_model_paths_ = os.listdir(prefix)
+                base_model_paths_ = [
+                    folder
+                    for folder in base_model_paths_
+                    if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
+                ]
+                base_model_paths.update(base_model_paths_)
+        base_model_paths = sorted(list(base_model_paths))
         return {
             "required": {

@@ -63,6 +55,12 @@ class SVDQuantFluxLoraLoader:
                     "tooltip": "How strongly to modify the diffusion model. This value can be negative.",
                 },
             ),
+            "save_converted_lora": (
+                ["disable", "enable"],
+                {
+                    "tooltip": "If enabled, the converted LoRA will be saved as a .safetensors file in the save directory of your LoRA file."
+                },
+            ),
         }
     }

@@ -78,7 +76,15 @@ class SVDQuantFluxLoraLoader:
             "Currently, only one LoRA nodes can be applied."
         )

-    def load_lora(self, model, lora_name: str, lora_format: str, base_model_name: str, lora_strength: float):
+    def load_lora(
+        self,
+        model,
+        lora_name: str,
+        lora_format: str,
+        base_model_name: str,
+        lora_strength: float,
+        save_converted_lora: str,
+    ):
         if self.cur_lora_name == lora_name:
             if self.cur_lora_name == "None":
                 pass  # Do nothing since the lora is None

@@ -110,9 +116,22 @@ class SVDQuantFluxLoraLoader:
             base_model_path = os.path.join(base_model_name, "transformer_blocks.safetensors")
             state_dict = convert_to_nunchaku_flux_lowrank_dict(base_model_path, input_lora)
-            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp_file:
-                save_file(state_dict, tmp_file.name)
-                model.model.diffusion_model.model.update_lora_params(tmp_file.name)
+            if save_converted_lora == "enable" and lora_format != "svdquant":
+                dirname = os.path.dirname(lora_path)
+                basename = os.path.basename(lora_path)
+                if "int4" in base_model_path:
+                    precision = "int4"
+                else:
+                    assert "fp4" in base_model_path
+                    precision = "fp4"
+                converted_name = f"svdq-{precision}-{basename}"
+                lora_converted_path = os.path.join(dirname, converted_name)
+                if not os.path.exists(lora_converted_path):
+                    save_file(state_dict, lora_converted_path)
+                    logger.info(f"Saved converted LoRA to: {lora_converted_path}")
+                else:
+                    logger.info(f"Converted LoRA already exists at: {lora_converted_path}")
+            model.model.diffusion_model.model.update_lora_params(state_dict)
         else:
             model.model.diffusion_model.model.update_lora_params(lora_path)
         model.model.diffusion_model.model.set_lora_strength(lora_strength)
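For reference, the new save_converted_lora option caches the conversion next to the original file under an svdq-{precision}- prefix, so the (potentially slow) conversion only happens once per LoRA. A small illustration with a hypothetical LoRA path; only the svdq-{precision}-{basename} naming comes from the diff, the directory is made up:

import os

lora_path = "/ComfyUI/models/loras/ghibsky.safetensors"  # hypothetical path
base_model_path = "mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors"

# Same precision sniffing as the hunk above: the base model id carries it.
precision = "int4" if "int4" in base_model_path else "fp4"
converted_name = f"svdq-{precision}-{os.path.basename(lora_path)}"
lora_converted_path = os.path.join(os.path.dirname(lora_path), converted_name)
print(lora_converted_path)  # /ComfyUI/models/loras/svdq-int4-ghibsky.safetensors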
=== comfyui/nodes/models/flux.py ===

 import os

 import comfy.model_patcher
 import folder_paths
 import torch

@@ -7,8 +8,10 @@ from comfy.supported_models import Flux, FluxSchnell
 from diffusers import FluxTransformer2DModel
 from einops import rearrange, repeat
 from torch import nn

 from nunchaku import NunchakuFluxTransformer2dModel


 class ComfyUIFluxForwardWrapper(nn.Module):
     def __init__(self, model: NunchakuFluxTransformer2dModel, config):
         super(ComfyUIFluxForwardWrapper, self).__init__()

@@ -59,18 +62,10 @@ class ComfyUIFluxForwardWrapper(nn.Module):
         out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
         return out


 class SVDQuantFluxDiTLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
         prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
         local_folders = set()
         for prefix in prefixes:

@@ -82,8 +77,7 @@ class SVDQuantFluxDiTLoader:
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
             local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths = local_folders + model_paths
+        model_paths = sorted(list(local_folders))
         ngpus = torch.cuda.device_count()
         return {
             "required": {

@@ -126,35 +120,37 @@ class SVDQuantFluxDiTLoader:
                 model_path = os.path.join(prefix, model_path)
                 break

-        # 验证 device_id 是否有效
+        # Check if the device_id is valid
         if device_id >= torch.cuda.device_count():
             raise ValueError(f"Invalid device_id: {device_id}. Only {torch.cuda.device_count()} GPUs available.")

-        # 获取 ComfyUI 指定 CUDA 设备的显存信息
+        # Get the GPU properties
         gpu_properties = torch.cuda.get_device_properties(device_id)
-        gpu_memory = gpu_properties.total_memory / (1024**2)  # 转换为 MB
+        gpu_memory = gpu_properties.total_memory / (1024**2)  # Convert to MB
         gpu_name = gpu_properties.name
-        print(f"GPU {device_id} ({gpu_name}) 显存: {gpu_memory} MB")
+        print(f"GPU {device_id} ({gpu_name}) Memory: {gpu_memory} MB")

-        # 确定 CPU offload 是否启用
+        # Check if CPU offload needs to be enabled
         if cpu_offload == "auto":
-            if gpu_memory < 14336:  # 14GB 阈值
+            if gpu_memory < 14336:  # 14GB threshold
                 cpu_offload_enabled = True
-                print("因显存小于 14GB, 启用 CPU offload")
+                print("VRAM < 14GiB, enable CPU offload")
             else:
                 cpu_offload_enabled = False
-                print("显存大于 14GB, 不启用 CPU offload")
+                print("VRAM > 14GiB, disable CPU offload")
         elif cpu_offload == "enable":
             cpu_offload_enabled = True
-            print("用户启用 CPU offload")
+            print("Enable CPU offload")
         else:
             cpu_offload_enabled = False
-            print("用户禁用 CPU offload")
-        # 清理 GPU 缓存
-        # torch.cuda.empty_cache()
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload_enabled)
+            print("Disable CPU offload")
+
+        capability = torch.cuda.get_device_capability(0)
+        sm = f"{capability[0]}{capability[1]}"
+        precision = "fp4" if sm == "120" else "int4"
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            model_path, precision=precision, offload=cpu_offload_enabled
+        )
         transformer = transformer.to(device)
         dit_config = {
             "image_model": "flux",
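The loader also stops assuming a quantization format and derives it from the GPU: FP4 kernels require compute capability 12.0 (SM 120, i.e. Blackwell, e.g. the RTX 50 series), and everything earlier falls back to INT4. A standalone sketch of the same check; note the diff above queries device 0, while the device_id parameter here is just for illustration:

import torch


def pick_precision(device_id: int = 0) -> str:
    # SM "120" == compute capability 12.0 (Blackwell). Only these GPUs take
    # the FP4 path; all earlier architectures use the INT4 kernels.
    major, minor = torch.cuda.get_device_capability(device_id)
    return "fp4" if f"{major}{minor}" == "120" else "int4"


if torch.cuda.is_available():
    print(pick_precision())  # e.g. "int4" on SM 86/89, "fp4" on SM 120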
=== comfyui/nodes/models/text_encoder.py ===

@@ -45,7 +45,6 @@ class WrappedEmbedding(nn.Module):
 class SVDQuantTextEncoderLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["mit-han-lab/svdq-flux.1-t5"]
         prefixes = folder_paths.folder_names_and_paths["text_encoders"][0]
         local_folders = set()
         for prefix in prefixes:

@@ -57,8 +56,7 @@ class SVDQuantTextEncoderLoader:
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
             local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths.extend(local_folders)
+        model_paths = sorted(list(local_folders))
         return {
             "required": {
                 "model_type": (["flux"],),

@@ -68,8 +66,8 @@ class SVDQuantTextEncoderLoader:
                     "INT",
                     {"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
                 ),
-                "t5_precision": (["BF16", "INT4"],),
-                "int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
+                "use_4bit_t5": (["disable", "enable"],),
+                "int4_model": (model_paths, {"tooltip": "The name of the 4-bit T5 model."}),
             }
         }

@@ -86,7 +84,7 @@ class SVDQuantTextEncoderLoader:
         text_encoder1: str,
         text_encoder2: str,
         t5_min_length: int,
-        t5_precision: str,
+        use_4bit_t5: str,
         int4_model: str,
     ):
         text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)

@@ -105,7 +103,7 @@ class SVDQuantTextEncoderLoader:
         if model_type == "flux":
             clip.tokenizer.t5xxl.min_length = t5_min_length

-        if t5_precision == "INT4":
+        if use_4bit_t5 == "enable":
             transformer = clip.cond_stage_model.t5xxl.transformer
             param = next(transformer.parameters())
             dtype = param.dtype
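For workflow authors: the text-encoder node input is renamed from t5_precision (values "BF16"/"INT4") to use_4bit_t5 (values "disable"/"enable"), so saved workflows that set the old key need updating. A purely hypothetical compatibility shim (the commit itself offers no fallback) showing how the two spellings map; the key names and values are the literal strings from the diff:

def wants_4bit_t5(inputs: dict) -> bool:
    # New-style key takes precedence; the legacy key is handled only to show
    # the mapping for hand-migrating old workflows. This helper is not part
    # of the commit.
    if "use_4bit_t5" in inputs:
        return inputs["use_4bit_t5"] == "enable"
    return inputs.get("t5_precision") == "INT4"


assert wants_4bit_t5({"use_4bit_t5": "enable"})
assert wants_4bit_t5({"t5_precision": "INT4"})
assert not wants_4bit_t5({"t5_precision": "BF16"})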
=== comfyui/nodes/preprocessors/depth.py ===

@@ -3,13 +3,12 @@ import os
 import folder_paths
 import numpy as np
 import torch
-from image_gen_aux import DepthPreprocessor


 class FluxDepthPreprocessor:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["LiheYoung/depth-anything-large-hf"]
+        model_paths = []
         prefix = os.path.join(folder_paths.models_dir, "checkpoints")
         local_folders = os.listdir(prefix)
         local_folders = sorted(

@@ -36,9 +35,13 @@ class FluxDepthPreprocessor:
     TITLE = "FLUX.1 Depth Preprocessor"

     def depth_preprocess(self, image, model_path):
-        prefix = os.path.join(folder_paths.models_dir, "checkpoints")
-        if os.path.exists(os.path.join(prefix, model_path)):
-            model_path = os.path.join(prefix, model_path)
+        prefixes = folder_paths.folder_names_and_paths["checkpoints"][0]
+        for prefix in prefixes:
+            if os.path.exists(os.path.join(prefix, model_path)):
+                model_path = os.path.join(prefix, model_path)
+                break
+
+        from image_gen_aux import DepthPreprocessor

         processor = DepthPreprocessor.from_pretrained(model_path)
         np_image = np.asarray(image)
         np_result = np.array(processor(np_image)[0].convert("RGB"))
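Two things happen in this hunk: the checkpoint is resolved against every registered "checkpoints" prefix rather than only models_dir/checkpoints, and the image_gen_aux import moves inside the method. My reading of the latter (not stated in the commit): deferring the import lets the node class register even when the optional dependency is absent, surfacing the ImportError only if the node actually runs. A sketch of the deferred-import pattern, reusing the API calls from the diff:

def depth_preprocess_sketch(image, model_path: str):
    # Deferred import: image_gen_aux is only required when the node executes,
    # not when ComfyUI enumerates and registers node classes at startup.
    from image_gen_aux import DepthPreprocessor

    processor = DepthPreprocessor.from_pretrained(model_path)
    return processor(image)[0].convert("RGB")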
=== examples/int4-flux.1-dev-qencoder-offload.py (deleted, 100644 → 0) ===

-import torch
-from diffusers import FluxPipeline
-
-from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
-
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    "mit-han-lab/svdq-int4-flux.1-dev", offload=True
-)  # set offload to False if you want to disable offloading
-text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
-pipeline = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev", text_encoder_2=text_encoder_2, transformer=transformer, torch_dtype=torch.bfloat16
-).to("cuda")
-pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
-image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
-image.save("flux.1-dev.png")
=== examples/int4-flux.1-schnell-qencoder-offload.py (deleted, 100644 → 0) ===

-import torch
-from diffusers import FluxPipeline
-
-from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
-
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
-)  # set offload to False if you want to disable offloading
-text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
-pipeline = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell",
-    text_encoder_2=text_encoder_2,
-    transformer=transformer,
-    torch_dtype=torch.bfloat16,
-).to("cuda")
-pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
-image = pipeline(
-    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
-).images[0]
-image.save("flux.1-schnell.png")
=== nunchaku/models/text_encoders/t5_encoder.py ===

@@ -60,12 +60,12 @@ def quantize_t5_encoder(
         if isinstance(module, nn.Linear):
             if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix):
                 print(f"Switching {name} to W4Linear")
-                qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
-                qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
-                if qmodule.bias is not None:
-                    qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
-                qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
-                qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
+                qmodule = W4Linear.from_linear(module, group_size=128, init_only=False)
+                # qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
+                # if qmodule.bias is not None:
+                #     qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
+                # qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
+                # qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])

                 # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                 qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
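This hunk flips init_only to False, so W4Linear.from_linear quantizes the live module's weights itself instead of allocating empty 4-bit buffers and copying qweight/scales/scaled_zeros from the checkpoint (those copies are now commented out). As a toy illustration of what "group-wise 4-bit quantization from in-memory weights" means — this is not nunchaku's kernel, just symmetric int4 with per-group scales and the same group size of 128 as in the diff:

import torch


def quantize_groupwise_int4(weight: torch.Tensor, group_size: int = 128):
    # Split each output row into groups of `group_size` input features and
    # compute one scale per group from the weights already in memory.
    out_f, in_f = weight.shape
    assert in_f % group_size == 0
    w = weight.reshape(out_f, in_f // group_size, group_size)
    scales = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0  # int4 range [-8, 7]
    q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)
    return q.reshape(out_f, in_f), scales.squeeze(-1)


qweight, scales = quantize_groupwise_int4(torch.randn(64, 256))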
=== nunchaku/models/transformers/transformer_flux.py ===

@@ -8,7 +8,6 @@ from huggingface_hub import utils
 from packaging.version import Version
 from torch import nn

-from nunchaku.utils import fetch_or_download
 from .utils import NunchakuModelLoaderMixin, pad_tensor
 from ..._C import QuantizedFluxModel, utils as cutils
 from ...utils import load_state_dict_in_safetensors

@@ -224,13 +223,18 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
                 new_state_dict[k] = v
         self.load_state_dict(new_state_dict, strict=True)

-    def update_lora_params(self, path: str):
-        state_dict = load_state_dict_in_safetensors(path)
+    def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
+        if isinstance(path_or_state_dict, dict):
+            state_dict = path_or_state_dict
+        else:
+            state_dict = load_state_dict_in_safetensors(path_or_state_dict)

         unquantized_loras = {}
         for k in state_dict.keys():
             if "transformer_blocks" not in k:
                 unquantized_loras[k] = state_dict[k]

         for k in unquantized_loras.keys():
             state_dict.pop(k)

         self.unquantized_loras = unquantized_loras
         if len(unquantized_loras) > 0:

@@ -239,10 +243,9 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
             self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
             self.update_unquantized_lora_params(1)

-        path = fetch_or_download(path)
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
-        block.m.load(path, True)
+        block.m.loadDict(path_or_state_dict, True)

     def set_lora_strength(self, strength: float = 1):
         block = self.transformer_blocks[0]
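This is the library-level half of the fix: update_lora_params now accepts either a .safetensors path or an in-memory state dict, and the quantized blocks are fed through block.m.loadDict instead of the file-based block.m.load (dropping the fetch_or_download hop). A usage sketch; the repo id comes from the diffs above, while the LoRA filename is hypothetical:

from safetensors.torch import load_file

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
state_dict = load_file("svdq-int4-my-lora.safetensors")  # hypothetical file
transformer.update_lora_params(state_dict)               # new: dict accepted, no temp file
transformer.update_lora_params("svdq-int4-my-lora.safetensors")  # path form still works
transformer.set_lora_strength(0.8)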