Commit 07f07563 (unverified), fengzch-das/nunchaku

chore: release v0.3.1

Authored Jun 16, 2025 by Muyang Li; committed via GitHub on Jun 16, 2025.
Parents: 7214300d, ad92b16a
Changes: 30 files in the full commit; this page shows 10 changed files with 226 additions and 52 deletions (+226 -52).
  nunchaku/lora/flux/nunchaku_converter.py          +6   -0
  nunchaku/models/pulid/eva_clip/factory.py         +35  -8
  nunchaku/models/pulid/pulid_forward.py            +12  -0
  nunchaku/models/transformers/transformer_flux.py  +5   -5
  nunchaku/pipeline/pipeline_flux_pulid.py          +94  -28
  nunchaku/utils.py                                 +9   -0
  src/FluxModel.cpp                                 +21  -11
  src/FluxModel.h                                   +3   -0
  tests/flux/test_flux_dev_loras.py                 +38  -0
  tests/flux/test_flux_dev_pulid.py                 +3   -0
nunchaku/lora/flux/nunchaku_converter.py

@@ -12,8 +12,14 @@ from .diffusers_converter import to_diffusers
 from .packer import NunchakuWeightPacker
 from .utils import is_nunchaku_format, pad

+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+
 logger = logging.getLogger(__name__)

 # region utilities
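Note: the new logging setup is driven entirely by the LOG_LEVEL environment variable. A minimal standalone sketch of the same pattern (not tied to this module):

    import logging
    import os

    # Read LOG_LEVEL (default "INFO"); getattr falls back to logging.INFO
    # when the value is not a valid level name.
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    logging.basicConfig(
        level=getattr(logging, log_level, logging.INFO),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    logging.getLogger(__name__).debug("visible only when LOG_LEVEL=DEBUG")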
nunchaku/models/pulid/eva_clip/factory.py

@@ -3,11 +3,13 @@ import logging
 import os
 import re
 from copy import deepcopy
+from os import PathLike
 from pathlib import Path
 from typing import Optional, Tuple, Union

 import torch

+from ....utils import fetch_or_download
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
 from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model

@@ -227,6 +229,7 @@ def create_model(
     pretrained_text_model: str = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model_name = model_name.replace("/", "-")  # for callers using old naming with / in ViT names
     if isinstance(device, str):

@@ -239,8 +242,35 @@ def create_model(
     if model_cfg is not None:
         logging.info(f"Loaded {model_name} model config.")
     else:
-        logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
-        raise RuntimeError(f"Model config for {model_name} not found.")
+        model_cfg = {
+            "embed_dim": 768,
+            "vision_cfg": {
+                "image_size": 336,
+                "layers": 24,
+                "width": 1024,
+                "drop_path_rate": 0,
+                "head_width": 64,
+                "mlp_ratio": 2.6667,
+                "patch_size": 14,
+                "eva_model_name": "eva-clip-l-14-336",
+                "xattn": True,
+                "fusedLN": True,
+                "rope": True,
+                "pt_hw_seq_len": 16,
+                "intp_freq": True,
+                "naiveswiglu": True,
+                "subln": True,
+            },
+            "text_cfg": {
+                "context_length": 77,
+                "vocab_size": 49408,
+                "width": 768,
+                "heads": 12,
+                "layers": 12,
+                "xattn": False,
+                "fusedLN": True,
+            },
+        }

     if "rope" in model_cfg.get("vision_cfg", {}):
         if model_cfg["vision_cfg"]["rope"]:

@@ -270,12 +300,7 @@ def create_model(
     pretrained_cfg = {}
     if pretrained:
-        checkpoint_path = ""
-        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
-        if pretrained_cfg:
-            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
-        elif os.path.exists(pretrained):
-            checkpoint_path = pretrained
+        checkpoint_path = fetch_or_download(pretrained_path)

         if checkpoint_path:
             logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")

@@ -379,6 +404,7 @@ def create_model_and_transforms(
     image_std: Optional[Tuple[float, ...]] = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model = create_model(
         model_name,

@@ -396,6 +422,7 @@ def create_model_and_transforms(
         pretrained_text_model=pretrained_text_model,
         cache_dir=cache_dir,
         skip_list=skip_list,
+        pretrained_path=pretrained_path,
     )
     image_mean = image_mean or getattr(model.visual, "image_mean", None)
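Note: create_model() and create_model_and_transforms() now take a pretrained_path that is resolved through fetch_or_download(), so callers can pass either a local file or a Hugging Face "user/repo/filename" spec. A hedged usage sketch (the import path follows the tree above; the default value is the one from the diff):

    from nunchaku.models.pulid.eva_clip import create_model_and_transforms

    # pretrained_path may be a local checkpoint or a Hub spec;
    # fetch_or_download() downloads it when it does not exist locally.
    model, _, _ = create_model_and_transforms(
        "EVA02-CLIP-L-14-336",
        "eva_clip",
        force_custom_clip=True,
        pretrained_path="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
    )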
nunchaku/models/pulid/pulid_forward.py

@@ -24,6 +24,8 @@ def pulid_forward(
     controlnet_single_block_samples=None,
     return_dict: bool = True,
     controlnet_blocks_repeat: bool = False,
+    start_timestep: float | None = None,
+    end_timestep: float | None = None,
 ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
     """
     Copied from diffusers.models.flux.transformer_flux.py

@@ -53,6 +55,16 @@ def pulid_forward(
     """
     hidden_states = self.x_embedder(hidden_states)

+    if timestep.numel() > 1:
+        timestep_float = timestep.flatten()[0].item()
+    else:
+        timestep_float = timestep.item()
+
+    if start_timestep is not None and start_timestep > timestep_float:
+        id_embeddings = None
+    if end_timestep is not None and end_timestep < timestep_float:
+        id_embeddings = None
+
     timestep = timestep.to(hidden_states.dtype) * 1000
     if guidance is not None:
         guidance = guidance.to(hidden_states.dtype) * 1000
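Note: the added block gives pulid_forward() an activity window for the ID embeddings, keyed on the unscaled timestep. A standalone restatement of the gating logic (the function name is illustrative, not part of the diff):

    import torch

    def gate_id_embeddings(id_embeddings, timestep: torch.Tensor,
                           start_timestep=None, end_timestep=None):
        # Reduce a possibly batched timestep tensor to a single float,
        # exactly as the diff does, then drop embeddings outside the window.
        if timestep.numel() > 1:
            t = timestep.flatten()[0].item()
        else:
            t = timestep.item()
        if start_timestep is not None and start_timestep > t:
            return None
        if end_timestep is not None and end_timestep < t:
            return None
        return id_embeddings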
nunchaku/models/transformers/transformer_flux.py

@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         self.id_weight = id_weight
         self.pulid_ca_idx = 0
         if self.id_embeddings is not None:
-            self.set_residual_callback()
+            self.set_pulid_residual_callback()
         original_dtype = hidden_states.dtype
         original_device = hidden_states.device

@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         )
         if self.id_embeddings is not None:
-            self.reset_residual_callback()
+            self.reset_pulid_residual_callback()
         hidden_states = hidden_states.to(original_dtype).to(original_device)

@@ -194,21 +194,21 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         return encoder_hidden_states, hidden_states

-    def set_residual_callback(self):
+    def set_pulid_residual_callback(self):
         id_embeddings = self.id_embeddings
         pulid_ca = self.pulid_ca
         pulid_ca_idx = [self.pulid_ca_idx]
         id_weight = self.id_weight

         def callback(hidden_states):
-            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
+            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
             pulid_ca_idx[0] += 1
             return ip

         self.callback_holder = callback
         self.m.set_residual_callback(callback)

-    def reset_residual_callback(self):
+    def reset_pulid_residual_callback(self):
         self.callback_holder = None
         self.m.set_residual_callback(None)
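Note: besides the rename to set_pulid_residual_callback()/reset_pulid_residual_callback(), the callback no longer forces hidden_states onto CUDA; the C++ side now passes the device tensor directly (see src/FluxModel.cpp below). The one-element list pulid_ca_idx acts as a mutable cell so the closure can step through the cross-attention adapters across calls; a minimal sketch of that pattern with stand-in adapters:

    # Stand-ins for the pulid_ca modules; each call advances to the next one.
    adapters = [lambda h: h * 0.1, lambda h: h * 0.2, lambda h: h * 0.3]
    idx = [0]  # one-element list: mutable across calls, like pulid_ca_idx

    def callback(hidden_states):
        out = adapters[idx[0]](hidden_states)
        idx[0] += 1
        return out

    print(callback(1.0), callback(1.0))  # 0.1 0.2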
nunchaku/pipeline/pipeline_flux_pulid.py

 # Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
 import gc
+import logging
+import os
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union

 import cv2

@@ -13,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import snapshot_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from insightface.app import FaceAnalysis
-from safetensors.torch import load_file
 from torch import nn
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import normalize, resize

@@ -24,10 +27,54 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
 from ..models.pulid.eva_clip import create_model_and_transforms
 from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
+from ..models.transformers import NunchakuFluxTransformer2dModel
+from ..utils import load_state_dict_in_safetensors, sha256sum
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
+    antelopev2_dirpath = Path(antelopev2_dirpath)
+    required_files = {
+        "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
+        "2d106det.onnx": "f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf",
+        "genderage.onnx": "4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb",
+        "glintr100.onnx": "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf",
+        "scrfd_10g_bnkps.onnx": "5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91",
+    }
+    if not antelopev2_dirpath.is_dir():
+        logger.debug(f"Directory does not exist: {antelopev2_dirpath}")
+        return False
+    for filename, expected_hash in required_files.items():
+        filepath = antelopev2_dirpath / filename
+        if not filepath.exists():
+            logger.debug(f"Missing file: {filename}")
+            return False
+        if expected_hash != "<SKIP_HASH>" and not sha256sum(filepath) == expected_hash:
+            logger.debug(f"Hash mismatch for: {filename}")
+            return False
+    return True
+
+
 class PuLIDPipeline(nn.Module):
-    def __init__(self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", *args, **kwargs):
+    def __init__(
+        self,
+        dit: NunchakuFluxTransformer2dModel,
+        device: str | torch.device,
+        weight_dtype: str | torch.dtype = torch.bfloat16,
+        onnx_provider: str = "gpu",
+        pulid_path: str | os.PathLike[str] = "guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
+        eva_clip_path: str | os.PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
+        insightface_dirpath: str | os.PathLike[str] | None = None,
+        facexlib_dirpath: str | os.PathLike[str] | None = None,
+    ):
         super().__init__()
         self.device = device
         self.weight_dtype = weight_dtype

@@ -50,6 +97,11 @@ class PuLIDPipeline(nn.Module):
         # preprocessors
         # face align and parsing
+        if facexlib_dirpath is None:
+            facexlib_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "facexlib"
+        facexlib_dirpath = Path(facexlib_dirpath)
+
         self.face_helper = FaceRestoreHelper(
             upscale_factor=1,
             face_size=512,

@@ -57,11 +109,17 @@ class PuLIDPipeline(nn.Module):
             det_model="retinaface_resnet50",
             save_ext="png",
             device=self.device,
+            model_rootpath=str(facexlib_dirpath),
         )
         self.face_helper.face_parse = None
-        self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device)
+        self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device, model_rootpath=str(facexlib_dirpath))
         # clip-vit backbone
-        model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True)
+        model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True, pretrained_path=eva_clip_path)
         model = model.visual
         self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
         eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)

@@ -72,41 +130,51 @@ class PuLIDPipeline(nn.Module):
         eva_transform_std = (eva_transform_std,) * 3
         self.eva_transform_mean = eva_transform_mean
         self.eva_transform_std = eva_transform_std

         # antelopev2
-        snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
+        if insightface_dirpath is None:
+            insightface_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "insightface"
+        insightface_dirpath = Path(insightface_dirpath)
+        if insightface_dirpath is not None:
+            antelopev2_dirpath = insightface_dirpath / "models" / "antelopev2"
+        else:
+            antelopev2_dirpath = None
+        if antelopev2_dirpath is None or not check_antelopev2_dir(antelopev2_dirpath):
+            snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dirpath)
         providers = (
             ["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
         )
-        self.app = FaceAnalysis(name="antelopev2", root=".", providers=providers)
+        self.app = FaceAnalysis(name="antelopev2", root=insightface_dirpath, providers=providers)
         self.app.prepare(ctx_id=0, det_size=(640, 640))
-        self.handler_ante = insightface.model_zoo.get_model("models/antelopev2/glintr100.onnx", providers=providers)
+        self.handler_ante = insightface.model_zoo.get_model(str(antelopev2_dirpath / "glintr100.onnx"), providers=providers)
         self.handler_ante.prepare(ctx_id=0)

-        gc.collect()
-        torch.cuda.empty_cache()
-
-        # other configs
-        self.debug_img_list = []
-
-    def load_pretrain(self, pretrain_path=None, version="v0.9.0"):
-        hf_hub_download("guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir="models")
-        ckpt_path = f"models/pulid_flux_{version}.safetensors"
-        if pretrain_path is not None:
-            ckpt_path = pretrain_path
-        state_dict = load_file(ckpt_path)
-        state_dict_dict = {}
+        # pulid model
+        state_dict = load_state_dict_in_safetensors(pulid_path)
+        module_state_dict = {}
         for k, v in state_dict.items():
             module = k.split(".")[0]
-            state_dict_dict.setdefault(module, {})
+            module_state_dict.setdefault(module, {})
             new_k = k[len(module) + 1 :]
-            state_dict_dict[module][new_k] = v
-        for module in state_dict_dict:
-            print(f"loading from {module}")
-            getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
+            module_state_dict[module][new_k] = v
+        for module in module_state_dict:
+            logging.debug(f"loading from {module}")
+            getattr(self, module).load_state_dict(module_state_dict[module], strict=True)
         del state_dict
-        del state_dict_dict
+        del module_state_dict
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # other configs
+        self.debug_img_list = []

     def to_gray(self, img):
         x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]

@@ -206,7 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline):
         pulid_device="cuda",
         weight_dtype=torch.bfloat16,
         onnx_provider="gpu",
-        pretrained_model=None,
     ):
         super().__init__(
             scheduler=scheduler,

@@ -232,7 +299,6 @@ class PuLIDFluxPipeline(FluxPipeline):
             weight_dtype=self.weight_dtype,
             onnx_provider=self.onnx_provider,
         )
-        self.pulid_model.load_pretrain(pretrained_model)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
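Note: check_antelopev2_dir() makes the antelopev2 download idempotent; the snapshot is fetched only when a required ONNX file is missing or fails its SHA-256 pin. A hedged sketch of the same verify-then-download flow, assuming check_antelopev2_dir from the diff above is importable from this module:

    from pathlib import Path

    from huggingface_hub import snapshot_download
    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

    from nunchaku.pipeline.pipeline_flux_pulid import check_antelopev2_dir

    # Default location used by the new __init__ when insightface_dirpath is None.
    antelopev2_dir = Path(HUGGINGFACE_HUB_CACHE) / "insightface" / "models" / "antelopev2"
    if not check_antelopev2_dir(antelopev2_dir):  # re-download only on miss or mismatch
        snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dir)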
nunchaku/utils.py

+import hashlib
 import os
 import warnings
 from pathlib import Path

@@ -7,6 +8,14 @@ import torch
 from huggingface_hub import hf_hub_download


+def sha256sum(filepath: str | os.PathLike[str]) -> str:
+    sha256 = hashlib.sha256()
+    with open(filepath, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            sha256.update(chunk)
+    return sha256.hexdigest()
+
+
 def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
     path = Path(path)
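Note: sha256sum() streams the file in 8 KiB chunks, so hashing multi-gigabyte checkpoints keeps memory flat. Usage sketch (the file path is illustrative):

    from nunchaku.utils import sha256sum

    digest = sha256sum("pulid_flux_v0.9.1.safetensors")  # any local file path
    print(digest)  # 64-char hex digest, suitable for pinning as in required_files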
src/FluxModel.cpp

@@ -837,11 +837,8 @@ Tensor FluxModel::forward(Tensor hidden_states,
                 hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
             }
             if (residual_callback && layer % 2 == 0) {
-                Tensor cpu_input = hidden_states.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
-                hidden_states = kernels::add(hidden_states, residual);
+                Tensor residual = residual_callback(hidden_states);
+                hidden_states = kernels::add(hidden_states, residual);
             }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {

@@ -875,12 +872,9 @@ Tensor FluxModel::forward(Tensor hidden_states,
             size_t local_layer_idx = layer - transformer_blocks.size();
             if (residual_callback && local_layer_idx % 4 == 0) {
                 Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                Tensor cpu_input = callback_input.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
-                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                slice = kernels::add(slice, residual);
+                Tensor residual = residual_callback(callback_input);
+                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                slice = kernels::add(slice, residual);
                 hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
             }
         }

@@ -919,6 +913,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                     Tensor controlnet_block_samples,
                                                     Tensor controlnet_single_block_samples) {
+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->loadLazyParams();
+        } else {
+            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
+        }
+    }
+
     if (layer < transformer_blocks.size()) {
         std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

@@ -954,6 +956,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
         hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
     }

+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->releaseLazyParams();
+        } else {
+            transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
+        }
+    }
+
     return {hidden_states, encoder_hidden_states};
 }
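Note: both hunks delete the CPU round-trip (copy to host, invoke the Python callback, copy back) together with the explicit gil_scoped_acquire at the call site; residual_callback now receives the device tensor directly, matching the Python-side change in transformer_flux.py that dropped hidden_states.to("cuda"). A hedged sketch of what a callback registered through the Python bindings can now assume:

    import torch

    def residual_callback(hidden_states: torch.Tensor) -> torch.Tensor:
        # With the round-trip removed, the input arrives on the GPU and the
        # returned residual must live there too; no .to("cuda") is needed.
        assert hidden_states.is_cuda
        return torch.zeros_like(hidden_states)  # placeholder residual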
src/FluxModel.h

@@ -189,6 +189,9 @@ public:
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
     std::function<Tensor(const Tensor &)> residual_callback;
+    bool isOffloadEnabled() const {
+        return offload;
+    }

 private:
     bool offload;
tests/flux/test_flux_dev_loras.py

 import pytest
+import torch
+from diffusers import FluxPipeline
+
+from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision, is_turing

 from .utils import run_test

@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         cache_threshold=0,
         expected_lpips=0.310 if get_precision() == "int4" else 0.168,
     )
+
+
+def test_kohya_lora():
+    precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    )
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+    ).to("cuda")
+    transformer.update_lora_params("mit-han-lab/nunchaku-test-models/hand_drawn_game.safetensors")
+    transformer.set_lora_strength(1)
+    prompt = (
+        "masterful impressionism oil painting titled 'the violinist', the composition follows the rule of thirds, "
+        "placing the violinist centrally in the frame. the subject is a young woman with fair skin and light blonde "
+        "hair is styled in a long, flowing hairstyle with natural waves. she is dressed in an opulent, "
+        "luxurious silver silk gown with a high waist and intricate gold detailing along the bodice. "
+        "the gown's texture is smooth and reflective. she holds a violin under her chin, "
+        "her right hand poised to play, and her left hand supporting the neck of the instrument. "
+        "she wears a delicate gold necklace with small, sparkling gemstones that catch the light. "
+        "her beautiful eyes focused on the viewer. the background features an elegantly furnished room "
+        "with classical late 19th century decor. to the left, there is a large, ornate portrait of "
+        "a man in a dark suit, set in a gilded frame. below this, a wooden desk with a closed book. "
+        "to the right, a red upholstered chair with a wooden frame is partially visible. "
+        "the room is bathed in natural light streaming through a window with red curtains, "
+        "creating a warm, inviting atmosphere. the lighting highlights the violinist, "
+        "casting soft shadows that enhance the depth and realism of the scene, highly aesthetic, "
+        "harmonious colors, impressioniststrokes, "
+        "<lora:style-impressionist_strokes-flux-by_daalis:1.0> <lora:image_upgrade-flux-by_zeronwo7829:1.0>"
+    )
+    image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
+    image.save(f"flux.1-dev-{precision}-1.png")
tests/flux/test_flux_dev_pulid.py

+import gc
 from types import MethodType

 import numpy as np

@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_pulid():
+    gc.collect()
+    torch.cuda.empty_cache()
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
         f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"