fengzch-das / nunchaku · Commits

Commit 07f07563 (unverified)
Authored Jun 16, 2025 by Muyang Li; committed via GitHub, Jun 16, 2025

chore: release v0.3.1

Parents: 7214300d, ad92b16a
Changes: 30 in total; this page shows 10 changed files with 226 additions and 52 deletions (+226 −52).
nunchaku/lora/flux/nunchaku_converter.py            +6    −0
nunchaku/models/pulid/eva_clip/factory.py           +35   −8
nunchaku/models/pulid/pulid_forward.py              +12   −0
nunchaku/models/transformers/transformer_flux.py    +5    −5
nunchaku/pipeline/pipeline_flux_pulid.py            +94   −28
nunchaku/utils.py                                   +9    −0
src/FluxModel.cpp                                   +21   −11
src/FluxModel.h                                     +3    −0
tests/flux/test_flux_dev_loras.py                   +38   −0
tests/flux/test_flux_dev_pulid.py                   +3    −0
nunchaku/lora/flux/nunchaku_converter.py

```diff
@@ -12,8 +12,14 @@ from .diffusers_converter import to_diffusers
 from .packer import NunchakuWeightPacker
 from .utils import is_nunchaku_format, pad
 
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # region utilities
```
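The same environment-driven logging setup appears verbatim in nunchaku/pipeline/pipeline_flux_pulid.py below. A minimal sketch (standard library only, not part of the commit) of how the `LOG_LEVEL` lookup behaves: `getattr` silently maps an unrecognized value back to `logging.INFO` instead of raising.

```python
# Minimal sketch of the LOG_LEVEL fallback behavior introduced above.
import logging

for value in ("DEBUG", "WARNING", "not_a_level"):
    level = getattr(logging, value.upper(), logging.INFO)
    print(f"LOG_LEVEL={value} -> {logging.getLevelName(level)}")
# LOG_LEVEL=DEBUG -> DEBUG
# LOG_LEVEL=WARNING -> WARNING
# LOG_LEVEL=not_a_level -> INFO
```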
nunchaku/models/pulid/eva_clip/factory.py

```diff
@@ -3,11 +3,13 @@ import logging
 import os
 import re
 from copy import deepcopy
+from os import PathLike
 from pathlib import Path
 from typing import Optional, Tuple, Union
 
 import torch
 
+from ....utils import fetch_or_download
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
 from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
@@ -227,6 +229,7 @@ def create_model(
     pretrained_text_model: str = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model_name = model_name.replace("/", "-")  # for callers using old naming with / in ViT names
     if isinstance(device, str):
@@ -239,8 +242,35 @@ def create_model(
     if model_cfg is not None:
         logging.info(f"Loaded {model_name} model config.")
     else:
-        logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
-        raise RuntimeError(f"Model config for {model_name} not found.")
+        model_cfg = {
+            "embed_dim": 768,
+            "vision_cfg": {
+                "image_size": 336,
+                "layers": 24,
+                "width": 1024,
+                "drop_path_rate": 0,
+                "head_width": 64,
+                "mlp_ratio": 2.6667,
+                "patch_size": 14,
+                "eva_model_name": "eva-clip-l-14-336",
+                "xattn": True,
+                "fusedLN": True,
+                "rope": True,
+                "pt_hw_seq_len": 16,
+                "intp_freq": True,
+                "naiveswiglu": True,
+                "subln": True,
+            },
+            "text_cfg": {
+                "context_length": 77,
+                "vocab_size": 49408,
+                "width": 768,
+                "heads": 12,
+                "layers": 12,
+                "xattn": False,
+                "fusedLN": True,
+            },
+        }
 
     if "rope" in model_cfg.get("vision_cfg", {}):
         if model_cfg["vision_cfg"]["rope"]:
@@ -270,12 +300,7 @@ def create_model(
     pretrained_cfg = {}
     if pretrained:
-        checkpoint_path = ""
-        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
-        if pretrained_cfg:
-            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
-        elif os.path.exists(pretrained):
-            checkpoint_path = pretrained
+        checkpoint_path = fetch_or_download(pretrained_path)
 
         if checkpoint_path:
             logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
@@ -379,6 +404,7 @@ def create_model_and_transforms(
     image_std: Optional[Tuple[float, ...]] = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model = create_model(
         model_name,
@@ -396,6 +422,7 @@ def create_model_and_transforms(
         pretrained_text_model=pretrained_text_model,
         cache_dir=cache_dir,
         skip_list=skip_list,
+        pretrained_path=pretrained_path,
     )
 
     image_mean = image_mean or getattr(model.visual, "image_mean", None)
```
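The new `pretrained_path` threads from `create_model_and_transforms()` into `create_model()`, which now resolves the checkpoint through `fetch_or_download()` (see nunchaku/utils.py below) rather than the removed `get_pretrained_cfg`/`download_pretrained` branch. A hedged usage sketch, mirroring the call made in pipeline_flux_pulid.py; treat it as illustrative rather than canonical:

```python
# Hedged sketch: load EVA-CLIP with an explicit checkpoint location. The
# default value below is the one in the diff; a local .pt path should also
# resolve via fetch_or_download().
from nunchaku.models.pulid.eva_clip import create_model_and_transforms

model, _, _ = create_model_and_transforms(
    "EVA02-CLIP-L-14-336",
    "eva_clip",
    force_custom_clip=True,
    pretrained_path="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
)
```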
nunchaku/models/pulid/pulid_forward.py

```diff
@@ -24,6 +24,8 @@ def pulid_forward(
     controlnet_single_block_samples=None,
     return_dict: bool = True,
     controlnet_blocks_repeat: bool = False,
+    start_timestep: float | None = None,
+    end_timestep: float | None = None,
 ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
     """
     Copied from diffusers.models.flux.transformer_flux.py
@@ -53,6 +55,16 @@ def pulid_forward(
     """
     hidden_states = self.x_embedder(hidden_states)
 
+    if timestep.numel() > 1:
+        timestep_float = timestep.flatten()[0].item()
+    else:
+        timestep_float = timestep.item()
+
+    if start_timestep is not None and start_timestep > timestep_float:
+        id_embeddings = None
+    if end_timestep is not None and end_timestep < timestep_float:
+        id_embeddings = None
+
     timestep = timestep.to(hidden_states.dtype) * 1000
     if guidance is not None:
         guidance = guidance.to(hidden_states.dtype) * 1000
```
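The two new arguments clamp identity injection to a timestep window: `id_embeddings` survives only while `start_timestep <= timestep_float <= end_timestep`, and the comparison uses the raw flow-matching timestep, before the ×1000 scaling. A minimal sketch of just that gating logic, with illustrative values:

```python
# Minimal sketch of the gating added above; values are illustrative.
def gate(id_embeddings, t, start_timestep=None, end_timestep=None):
    # Embeddings survive only when start_timestep <= t <= end_timestep.
    if start_timestep is not None and start_timestep > t:
        id_embeddings = None
    if end_timestep is not None and end_timestep < t:
        id_embeddings = None
    return id_embeddings

assert gate("ID", 0.5, start_timestep=0.2, end_timestep=0.8) == "ID"
assert gate("ID", 0.1, start_timestep=0.2, end_timestep=0.8) is None
assert gate("ID", 0.9, start_timestep=0.2, end_timestep=0.8) is None
```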
nunchaku/models/transformers/transformer_flux.py

```diff
@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         self.id_weight = id_weight
         self.pulid_ca_idx = 0
         if self.id_embeddings is not None:
-            self.set_residual_callback()
+            self.set_pulid_residual_callback()
 
         original_dtype = hidden_states.dtype
         original_device = hidden_states.device
@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         )
 
         if self.id_embeddings is not None:
-            self.reset_residual_callback()
+            self.reset_pulid_residual_callback()
 
         hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -194,21 +194,21 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         return encoder_hidden_states, hidden_states
 
-    def set_residual_callback(self):
+    def set_pulid_residual_callback(self):
         id_embeddings = self.id_embeddings
         pulid_ca = self.pulid_ca
         pulid_ca_idx = [self.pulid_ca_idx]
         id_weight = self.id_weight
 
         def callback(hidden_states):
-            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
+            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
             pulid_ca_idx[0] += 1
             return ip
 
         self.callback_holder = callback
         self.m.set_residual_callback(callback)
 
-    def reset_residual_callback(self):
+    def reset_pulid_residual_callback(self):
         self.callback_holder = None
         self.m.set_residual_callback(None)
```
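The renames distinguish the PuLID-specific wrapper from the generic hook it rides on: Python-side PuLID registers an ordinary residual callback on the C++ module via `self.m.set_residual_callback(...)`. Dropping the `.to("cuda")` matches the src/FluxModel.cpp change below, which stops round-tripping hidden states through the CPU. A hedged sketch of the callback contract as the diff implies it; the scaled no-op callback here is illustrative, not part of the codebase:

```python
# Hedged sketch of the residual-callback contract: the backend periodically
# invokes the registered callable with the current hidden states (kept on
# CUDA after this commit) and adds the returned tensor in as a residual.
import torch

def make_scaled_residual_callback(scale: float):
    def callback(hidden_states: torch.Tensor) -> torch.Tensor:
        # PuLID instead returns id_weight * pulid_ca[idx](id_embeddings, hidden_states).
        return scale * hidden_states
    return callback

# block.m.set_residual_callback(make_scaled_residual_callback(0.0))  # register
# block.m.set_residual_callback(None)                                # detach
```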
nunchaku/pipeline/pipeline_flux_pulid.py

```diff
 # Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
 import gc
+import logging
+import os
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import cv2
@@ -13,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import snapshot_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from insightface.app import FaceAnalysis
-from safetensors.torch import load_file
 from torch import nn
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import normalize, resize
@@ -24,10 +27,54 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
 from ..models.pulid.eva_clip import create_model_and_transforms
 from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
 from ..models.transformers import NunchakuFluxTransformer2dModel
-from ..utils import load_state_dict_in_safetensors
+from ..utils import load_state_dict_in_safetensors, sha256sum
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
+    antelopev2_dirpath = Path(antelopev2_dirpath)
+    required_files = {
+        "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
+        "2d106det.onnx": "f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf",
+        "genderage.onnx": "4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb",
+        "glintr100.onnx": "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf",
+        "scrfd_10g_bnkps.onnx": "5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91",
+    }
+    if not antelopev2_dirpath.is_dir():
+        logger.debug(f"Directory does not exist: {antelopev2_dirpath}")
+        return False
+    for filename, expected_hash in required_files.items():
+        filepath = antelopev2_dirpath / filename
+        if not filepath.exists():
+            logger.debug(f"Missing file: {filename}")
+            return False
+        if expected_hash != "<SKIP_HASH>" and not sha256sum(filepath) == expected_hash:
+            logger.debug(f"Hash mismatch for: {filename}")
+            return False
+    return True
 
 
 class PuLIDPipeline(nn.Module):
-    def __init__(self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", *args, **kwargs):
+    def __init__(
+        self,
+        dit: NunchakuFluxTransformer2dModel,
+        device: str | torch.device,
+        weight_dtype: str | torch.dtype = torch.bfloat16,
+        onnx_provider: str = "gpu",
+        pulid_path: str | os.PathLike[str] = "guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
+        eva_clip_path: str | os.PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
+        insightface_dirpath: str | os.PathLike[str] | None = None,
+        facexlib_dirpath: str | os.PathLike[str] | None = None,
+    ):
         super().__init__()
         self.device = device
         self.weight_dtype = weight_dtype
@@ -50,6 +97,11 @@ class PuLIDPipeline(nn.Module):
         # preprocessors
         # face align and parsing
+        if facexlib_dirpath is None:
+            facexlib_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "facexlib"
+        facexlib_dirpath = Path(facexlib_dirpath)
         self.face_helper = FaceRestoreHelper(
             upscale_factor=1,
             face_size=512,
@@ -57,11 +109,17 @@ class PuLIDPipeline(nn.Module):
             det_model="retinaface_resnet50",
             save_ext="png",
             device=self.device,
+            model_rootpath=str(facexlib_dirpath),
         )
         self.face_helper.face_parse = None
-        self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device)
+        self.face_helper.face_parse = init_parsing_model(
+            model_name="bisenet", device=self.device, model_rootpath=str(facexlib_dirpath)
+        )
 
         # clip-vit backbone
-        model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True)
+        model, _, _ = create_model_and_transforms(
+            "EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True, pretrained_path=eva_clip_path
+        )
         model = model.visual
         self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
         eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)
@@ -72,41 +130,51 @@ class PuLIDPipeline(nn.Module):
             eva_transform_std = (eva_transform_std,) * 3
         self.eva_transform_mean = eva_transform_mean
         self.eva_transform_std = eva_transform_std
 
         # antelopev2
-        snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
+        if insightface_dirpath is None:
+            insightface_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "insightface"
+        insightface_dirpath = Path(insightface_dirpath)
+        if insightface_dirpath is not None:
+            antelopev2_dirpath = insightface_dirpath / "models" / "antelopev2"
+        else:
+            antelopev2_dirpath = None
+        if antelopev2_dirpath is None or not check_antelopev2_dir(antelopev2_dirpath):
+            snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dirpath)
         providers = (
             ["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
         )
-        self.app = FaceAnalysis(name="antelopev2", root=".", providers=providers)
+        self.app = FaceAnalysis(name="antelopev2", root=insightface_dirpath, providers=providers)
         self.app.prepare(ctx_id=0, det_size=(640, 640))
-        self.handler_ante = insightface.model_zoo.get_model("models/antelopev2/glintr100.onnx", providers=providers)
+        self.handler_ante = insightface.model_zoo.get_model(str(antelopev2_dirpath / "glintr100.onnx"), providers=providers)
         self.handler_ante.prepare(ctx_id=0)
 
         gc.collect()
         torch.cuda.empty_cache()
 
-        # other configs
-        self.debug_img_list = []
-
-    def load_pretrain(self, pretrain_path=None, version="v0.9.0"):
-        hf_hub_download("guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir="models")
-        ckpt_path = f"models/pulid_flux_{version}.safetensors"
-        if pretrain_path is not None:
-            ckpt_path = pretrain_path
-        state_dict = load_file(ckpt_path)
-        state_dict_dict = {}
+        # pulid model
+        state_dict = load_state_dict_in_safetensors(pulid_path)
+        module_state_dict = {}
         for k, v in state_dict.items():
             module = k.split(".")[0]
-            state_dict_dict.setdefault(module, {})
+            module_state_dict.setdefault(module, {})
             new_k = k[len(module) + 1 :]
-            state_dict_dict[module][new_k] = v
+            module_state_dict[module][new_k] = v
 
-        for module in state_dict_dict:
-            print(f"loading from {module}")
-            getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
+        for module in module_state_dict:
+            logging.debug(f"loading from {module}")
+            getattr(self, module).load_state_dict(module_state_dict[module], strict=True)
 
         del state_dict
-        del state_dict_dict
+        del module_state_dict
 
         gc.collect()
         torch.cuda.empty_cache()
 
+        # other configs
+        self.debug_img_list = []
+
     def to_gray(self, img):
         x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
@@ -206,7 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline):
         pulid_device="cuda",
         weight_dtype=torch.bfloat16,
         onnx_provider="gpu",
-        pretrained_model=None,
     ):
         super().__init__(
             scheduler=scheduler,
@@ -232,7 +299,6 @@ class PuLIDFluxPipeline(FluxPipeline):
             weight_dtype=self.weight_dtype,
             onnx_provider=self.onnx_provider,
         )
-        self.pulid_model.load_pretrain(pretrained_model)
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
```
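Net effect of the pipeline changes: PuLID weights now load eagerly in `__init__` via `load_state_dict_in_safetensors(pulid_path)` (replacing the removed `load_pretrain`), and auxiliary assets default to subdirectories of the Hugging Face hub cache, with AntelopeV2 re-downloaded only when `check_antelopev2_dir` finds files missing or hash-mismatched. A hedged construction sketch using the new keyword arguments; the import path and the `transformer` variable (a loaded `NunchakuFluxTransformer2dModel`) are assumptions:

```python
# Hedged usage sketch of the reworked constructor (defaults as in the diff).
import torch
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDPipeline  # assumed import path

pipeline = PuLIDPipeline(
    dit=transformer,  # assumed: a loaded NunchakuFluxTransformer2dModel
    device="cuda",
    weight_dtype=torch.bfloat16,
    onnx_provider="gpu",
    pulid_path="guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
    eva_clip_path="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
    insightface_dirpath=None,  # None -> <HF hub cache>/insightface
    facexlib_dirpath=None,     # None -> <HF hub cache>/facexlib
)
```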
nunchaku/utils.py

```diff
+import hashlib
 import os
 import warnings
 from pathlib import Path
@@ -7,6 +8,14 @@ import torch
 from huggingface_hub import hf_hub_download
 
 
+def sha256sum(filepath: str | os.PathLike[str]) -> str:
+    sha256 = hashlib.sha256()
+    with open(filepath, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            sha256.update(chunk)
+    return sha256.hexdigest()
+
+
 def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
     path = Path(path)
```
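The helper hashes in 8 KiB chunks, so memory stays flat even for multi-hundred-megabyte ONNX weights. An illustrative check against a pinned digest (the hash value is copied from `check_antelopev2_dir` above; the local path is an assumption):

```python
# Illustrative: verify a downloaded AntelopeV2 model file before trusting it.
from nunchaku.utils import sha256sum

expected = "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf"
if sha256sum("models/antelopev2/glintr100.onnx") != expected:
    raise RuntimeError("glintr100.onnx missing or corrupted; re-download it")
```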
src/FluxModel.cpp

```diff
@@ -837,11 +837,8 @@ Tensor FluxModel::forward(Tensor hidden_states,
                 hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
             }
             if (residual_callback && layer % 2 == 0) {
-                Tensor cpu_input = hidden_states.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
-                hidden_states = kernels::add(hidden_states, residual);
+                Tensor residual = residual_callback(hidden_states);
+                hidden_states = kernels::add(hidden_states, residual);
             }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {
@@ -875,12 +872,9 @@ Tensor FluxModel::forward(Tensor hidden_states,
             size_t local_layer_idx = layer - transformer_blocks.size();
             if (residual_callback && local_layer_idx % 4 == 0) {
                 Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                Tensor cpu_input = callback_input.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
-                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                slice = kernels::add(slice, residual);
+                Tensor residual = residual_callback(callback_input);
+                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                slice = kernels::add(slice, residual);
                 hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
             }
         }
@@ -919,6 +913,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                     Tensor controlnet_block_samples,
                                                     Tensor controlnet_single_block_samples) {
+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->loadLazyParams();
+        } else {
+            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
+        }
+    }
+
     if (layer < transformer_blocks.size()) {
         std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
@@ -954,6 +956,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
         hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
     }
 
+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->releaseLazyParams();
+        } else {
+            transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
+        }
+    }
+
     return {hidden_states, encoder_hidden_states};
 }
```
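Two independent fixes here: the residual callback now receives the CUDA tensor directly (the CPU round-trip and the explicit GIL acquisition around the Python call are gone, matching the `.to("cuda")` removal in transformer_flux.py), and `forward_layer` gains per-layer offloading, materializing a block's lazy parameters just before its forward pass and releasing them right after. A Python-flavored sketch of that offload pattern; `load_lazy_params`/`release_lazy_params` are stand-ins for the C++ `loadLazyParams`/`releaseLazyParams`:

```python
# Hedged sketch of the offload pattern: at most roughly one block's weights
# are resident on the GPU at a time when offload is enabled.
def forward_layer(layer, blocks, hidden_states, offload):
    if offload and layer > 0:
        blocks[layer].load_lazy_params()      # copy this block's weights to GPU
    hidden_states = blocks[layer].forward(hidden_states)
    if offload and layer > 0:
        blocks[layer].release_lazy_params()   # free the GPU copy again
    return hidden_states
```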
src/FluxModel.h

```diff
@@ -189,6 +189,9 @@ public:
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
     std::function<Tensor(const Tensor &)> residual_callback;
 
+    bool isOffloadEnabled() const {
+        return offload;
+    }
 
 private:
     bool offload;
```
tests/flux/test_flux_dev_loras.py

```diff
 import pytest
+import torch
+from diffusers import FluxPipeline
 
+from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision, is_turing
 
 from .utils import run_test
@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         cache_threshold=0,
         expected_lpips=0.310 if get_precision() == "int4" else 0.168,
     )
+
+
+def test_kohya_lora():
+    precision = get_precision()  # auto-detects whether precision is 'int4' or 'fp4' based on your GPU
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    )
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+    ).to("cuda")
+    transformer.update_lora_params("mit-han-lab/nunchaku-test-models/hand_drawn_game.safetensors")
+    transformer.set_lora_strength(1)
+    prompt = (
+        "masterful impressionism oil painting titled 'the violinist', the composition follows the rule of thirds, "
+        "placing the violinist centrally in the frame. the subject is a young woman with fair skin and light blonde "
+        "hair is styled in a long, flowing hairstyle with natural waves. she is dressed in an opulent, "
+        "luxurious silver silk gown with a high waist and intricate gold detailing along the bodice. "
+        "the gown's texture is smooth and reflective. she holds a violin under her chin, "
+        "her right hand poised to play, and her left hand supporting the neck of the instrument. "
+        "she wears a delicate gold necklace with small, sparkling gemstones that catch the light. "
+        "her beautiful eyes focused on the viewer. the background features an elegantly furnished room "
+        "with classical late 19th century decor. to the left, there is a large, ornate portrait of "
+        "a man in a dark suit, set in a gilded frame. below this, a wooden desk with a closed book. "
+        "to the right, a red upholstered chair with a wooden frame is partially visible. "
+        "the room is bathed in natural light streaming through a window with red curtains, "
+        "creating a warm, inviting atmosphere. the lighting highlights the violinist, "
+        "casting soft shadows that enhance the depth and realism of the scene, highly aesthetic, "
+        "harmonious colors, impressioniststrokes, "
+        "<lora:style-impressionist_strokes-flux-by_daalis:1.0> <lora:image_upgrade-flux-by_zeronwo7829:1.0>"
+    )
+    image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
+    image.save(f"flux.1-dev-{precision}-1.png")
```
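The new test exercises the Kohya-format LoRA path end to end. The reusable pattern, per the calls above: load a quantized transformer, attach LoRA weights once, then rescale without reloading. A hedged fragment; the `.safetensors` path is illustrative, and `transformer`/`pipeline`/`prompt` are assumed to be set up as in the test:

```python
# Hedged sketch of the LoRA workflow the test exercises.
transformer.update_lora_params("path/to/kohya_style_lora.safetensors")  # illustrative path
transformer.set_lora_strength(0.8)  # adjust LoRA contribution without reloading
image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
```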
tests/flux/test_flux_dev_pulid.py

```diff
+import gc
 from types import MethodType
 
 import numpy as np
@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_pulid():
+    gc.collect()
+    torch.cuda.empty_cache()
     precision = get_precision()  # auto-detects whether precision is 'int4' or 'fp4' based on your GPU
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
         f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
```