fengzch-das / nunchaku · Commit 35a4d011

Add support to enable flux.1 tools in ComfyUI with int4

Authored Feb 13, 2025 by April Hu
Parent: 50139c73

Showing 1 changed file, comfyui/nodes.py, with 127 additions and 20 deletions.

comfyui/nodes.py
@@ -7,16 +7,17 @@ import comfy.sd
 import folder_paths
 import GPUtil
 import torch
+import numpy as np
 from comfy.ldm.common_dit import pad_to_patch_size
 from comfy.supported_models import Flux, FluxSchnell
 from diffusers import FluxTransformer2DModel
 from einops import rearrange, repeat
 from torch import nn
 from transformers import T5EncoderModel
+from image_gen_aux import DepthPreprocessor
 from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel


 class ComfyUIFluxForwardWrapper(nn.Module):
     def __init__(self, model: NunchakuFluxTransformer2dModel, config):
         super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -24,13 +25,25 @@ class ComfyUIFluxForwardWrapper(nn.Module):
         self.dtype = next(model.parameters()).dtype
         self.config = config

-    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
+    def forward(
+        self,
+        x,
+        timestep,
+        context,
+        y,
+        guidance,
+        control=None,
+        transformer_options={},
+        **kwargs,
+    ):
         assert control is None  # for now
         bs, c, h, w = x.shape
         patch_size = self.config["patch_size"]
         x = pad_to_patch_size(x, (patch_size, patch_size))
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         h_len = (h + (patch_size // 2)) // patch_size
         w_len = (w + (patch_size // 2)) // patch_size
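Note on the patchify step above: `pad_to_patch_size` pads the latent so its height and width are multiples of `patch_size`, the `rearrange` then flattens each patch into one token, and `h_len`/`w_len` are ceiling divisions of the original dimensions. A minimal shape check of that logic (the tensor sizes here are made up for illustration):

    import torch
    from einops import rearrange

    patch_size = 2
    x = torch.randn(1, 16, 128, 128)  # (b, c, h, w) latent, already patch-aligned
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
    print(img.shape)  # torch.Size([1, 4096, 64]): 64*64 tokens of 16*2*2 features each
    h_len = (128 + (patch_size // 2)) // patch_size  # ceil(128 / 2) = 64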
@@ -54,21 +67,30 @@ class ComfyUIFluxForwardWrapper(nn.Module):
             guidance=guidance if self.config["guidance_embed"] else None,
         ).sample
         out = rearrange(
             out,
             "b (h w) (c ph pw) -> b c (h ph) (w pw)",
             h=h_len,
             w=w_len,
             ph=2,
             pw=2,
         )[:, :, :h, :w]
         return out


 class SVDQuantFluxDiTLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
+        model_paths = [
+            "mit-han-lab/svdq-int4-flux.1-schnell",
+            "mit-han-lab/svdq-int4-flux.1-dev",
+            "mit-han-lab/svdq-int4-flux.1-canny-dev",
+            "mit-han-lab/svdq-int4-flux.1-depth-dev",
+            "mit-han-lab/svdq-int4-flux.1-fill-dev",
+        ]
         prefix = "models/diffusion_models"
         local_folders = os.listdir(prefix)
         local_folders = sorted(
             [
                 folder
                 for folder in local_folders
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
         )
         model_paths.extend(local_folders)
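The three new entries in the list above are the INT4-quantized FLUX.1 tools (Canny, Depth, Fill). Any folder found under `models/diffusion_models` is also offered, and `load_model` (next hunk) resolves a selected name to that local folder if it exists, otherwise passing it through as a Hugging Face repo id. A sketch of that convention, assuming nothing beyond `os.path`:

    import os

    def resolve(model_path: str, prefix: str = "models/diffusion_models") -> str:
        # Local checkout wins; otherwise treat the string as a HF repo id.
        local = os.path.join(prefix, model_path)
        return local if os.path.exists(local) else model_path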
@@ -78,7 +100,14 @@ class SVDQuantFluxDiTLoader:
                 "model_path": (model_paths,),
                 "device_id": (
                     "INT",
-                    {"default": 0, "min": 0, "max": ngpus, "step": 1, "display": "number", "lazy": True},
+                    {
+                        "default": 0,
+                        "min": 0,
+                        "max": ngpus,
+                        "step": 1,
+                        "display": "number",
+                        "lazy": True,
+                    },
                 ),
             }
         }
@@ -88,17 +117,20 @@ class SVDQuantFluxDiTLoader:
     CATEGORY = "SVDQuant"
     TITLE = "SVDQuant Flux DiT Loader"

     def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
         device = f"cuda:{device_id}"
         prefix = "models/diffusion_models"
         if os.path.exists(os.path.join(prefix, model_path)):
             model_path = os.path.join(prefix, model_path)
         else:
             model_path = model_path
         transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
         dit_config = {
             "image_model": "flux",
-            "in_channels": 16,
             "patch_size": 2,
             "out_channels": 16,
             "vec_in_dim": 768,
@@ -111,21 +143,34 @@ class SVDQuantFluxDiTLoader:
             "axes_dim": [16, 56, 56],
             "theta": 10000,
             "qkv_bias": True,
+            "guidance_embed": True,
             "disable_unet_model_creation": True,
         }
         if "schnell" in model_path:
             dit_config["guidance_embed"] = False
+            dit_config["in_channels"] = 16
             model_config = FluxSchnell(dit_config)
+        elif "canny" in model_path or "depth" in model_path:
+            dit_config["in_channels"] = 32
+            model_config = Flux(dit_config)
+        elif "fill" in model_path:
+            dit_config["in_channels"] = 64
+            model_config = Flux(dit_config)
         else:
-            assert "dev" in model_path
-            dit_config["guidance_embed"] = True
+            assert (
+                model_path == "mit-han-lab/svdq-int4-flux.1-dev"
+            ), f"model {model_path} not supported"
+            dit_config["in_channels"] = 16
             model_config = Flux(dit_config)
         model_config.set_inference_dtype(torch.bfloat16, None)
         model_config.custom_operations = None
         model = model_config.get_model({})
         model.diffusion_model = ComfyUIFluxForwardWrapper(transformer, config=dit_config)
         model = comfy.model_patcher.ModelPatcher(model, device, device_id)
         return (model,)
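Note on the `in_channels` branching above: the plain dev and schnell transformers take only the 16-channel image latent, while the tool variants take concatenated conditioning as well, presumably the control latent for Canny/Depth (32 total) and masked-image/mask channels for Fill (64 total). A hypothetical helper mirroring the branch logic (not part of the commit):

    def expected_in_channels(model_path: str) -> int:
        if "canny" in model_path or "depth" in model_path:
            return 32  # image latent + control latent
        if "fill" in model_path:
            return 64  # image latent + masked-image/mask conditioning
        return 16  # plain dev / schnell

    assert expected_in_channels("mit-han-lab/svdq-int4-flux.1-depth-dev") == 32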
@@ -157,7 +202,8 @@ class SVDQuantTextEncoderLoader:
                 folder
                 for folder in local_folders
                 if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
             ]
         )
         model_paths.extend(local_folders)
@@ -168,7 +214,14 @@ class SVDQuantTextEncoderLoader:
                 "text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
                 "t5_min_length": (
                     "INT",
-                    {"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
+                    {
+                        "default": 512,
+                        "min": 256,
+                        "max": 1024,
+                        "step": 128,
+                        "display": "number",
+                        "lazy": True,
+                    },
                 ),
                 "t5_precision": (["BF16", "INT4"],),
                 "int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
@@ -191,8 +244,12 @@ class SVDQuantTextEncoderLoader:
         t5_precision: str,
         int4_model: str,
     ):
-        text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
-        text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
+        text_encoder_path1 = folder_paths.get_full_path_or_raise(
+            "text_encoders", text_encoder1
+        )
+        text_encoder_path2 = folder_paths.get_full_path_or_raise(
+            "text_encoders", text_encoder2
+        )
         if model_type == "flux":
             clip_type = comfy.sd.CLIPType.FLUX
         else:
@@ -223,7 +280,9 @@ class SVDQuantTextEncoderLoader:
         transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
         transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
         clip.cond_stage_model.t5xxl.transformer = (
             transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
         )
         return (clip,)
@@ -239,11 +298,17 @@ class SVDQuantLoraLoader:
         lora_name_list = [
             "None",
             *folder_paths.get_filename_list("loras"),
             *[f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors" for n in hf_lora_names],
         ]
         return {
             "required": {
-                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
+                "model": (
+                    "MODEL",
+                    {"tooltip": "The diffusion model the LoRA will be applied to."},
+                ),
                 "lora_name": (lora_name_list, {"tooltip": "The name of the LoRA."}),
                 "lora_strength": (
                     "FLOAT",
@@ -292,8 +357,50 @@ class SVDQuantLoraLoader:
...
@@ -292,8 +357,50 @@ class SVDQuantLoraLoader:
return
(
model
,)
return
(
model
,)
class
DepthPreprocesser
:
@
classmethod
def
INPUT_TYPES
(
s
):
model_paths
=
[
"LiheYoung/depth-anything-large-hf"
]
prefix
=
"models/style_models"
local_folders
=
os
.
listdir
(
prefix
)
local_folders
=
sorted
(
[
folder
for
folder
in
local_folders
if
not
folder
.
startswith
(
"."
)
and
os
.
path
.
isdir
(
os
.
path
.
join
(
prefix
,
folder
))
]
)
model_paths
.
extend
(
local_folders
)
return
{
"required"
:
{
"image"
:
(
"IMAGE"
,
{}),
"model_path"
:
(
model_paths
,
{
"tooltip"
:
"Name of the depth preprocesser model."
},
),
}
}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"depth_preprocess"
CATEGORY
=
"Flux.1"
TITLE
=
"Flux.1 Depth Preprocessor"
def
depth_preprocess
(
self
,
image
,
model_path
):
prefix
=
"models/style_models"
if
os
.
path
.
exists
(
os
.
path
.
join
(
prefix
,
model_path
)):
model_path
=
os
.
path
.
join
(
prefix
,
model_path
)
processor
=
DepthPreprocessor
.
from_pretrained
(
model_path
)
np_image
=
np
.
asarray
(
image
)
np_result
=
np
.
array
(
processor
(
np_image
)[
0
].
convert
(
"RGB"
))
out_tensor
=
torch
.
from_numpy
(
np_result
.
astype
(
np
.
float32
)
/
255.0
).
unsqueeze
(
0
)
return
(
out_tensor
,)
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"SVDQuantFluxDiTLoader"
:
SVDQuantFluxDiTLoader
,
"SVDQuantFluxDiTLoader"
:
SVDQuantFluxDiTLoader
,
"SVDQuantTextEncoderLoader"
:
SVDQuantTextEncoderLoader
,
"SVDQuantTextEncoderLoader"
:
SVDQuantTextEncoderLoader
,
"SVDQuantLoRALoader"
:
SVDQuantLoraLoader
,
"SVDQuantLoRALoader"
:
SVDQuantLoraLoader
,
"DepthPreprocesser"
:
DepthPreprocesser
,
}
}
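Taken together, the new `DepthPreprocesser` node plus the extended loader enable a depth-to-image workflow with the INT4 FLUX.1 tools. A rough standalone sketch of how the pieces connect outside the ComfyUI graph (the image tensor is an assumption, a CUDA device and the downloaded models are required; in practice these calls are wired as graph nodes):

    import torch

    image = torch.rand(1, 512, 512, 3)  # ComfyUI-style image batch in [0, 1]

    # Produce a depth map, then load the INT4 depth DiT on GPU 0.
    (depth_map,) = DepthPreprocesser().depth_preprocess(image, "LiheYoung/depth-anything-large-hf")
    (model,) = SVDQuantFluxDiTLoader().load_model("mit-han-lab/svdq-int4-flux.1-depth-dev", device_id=0)
    # depth_map feeds the Flux depth conditioning; model plugs into the usual sampler.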