Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
8fbf418d
Unverified
Commit
8fbf418d
authored
Jun 12, 2025
by
Muyang Li
Committed by
GitHub
Jun 12, 2025
Browse files
feat: support kohya lora and loras with alphas (#459)
* Support Kohya-format LoRAs * Add a test for the LoRA * Add a gc.collect() call
parent
46f4251a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
104 additions
and
2 deletions
+104
-2
nunchaku/lora/flux/diffusers_converter.py
nunchaku/lora/flux/diffusers_converter.py
+57
-2
nunchaku/lora/flux/nunchaku_converter.py
nunchaku/lora/flux/nunchaku_converter.py
+6
-0
tests/flux/test_flux_dev_loras.py
tests/flux/test_flux_dev_loras.py
+38
-0
tests/flux/test_flux_dev_pulid.py
tests/flux/test_flux_dev_pulid.py
+3
-0
No files found.
nunchaku/lora/flux/diffusers_converter.py
View file @
8fbf418d
import
argparse
import
argparse
import
logging
import
os
import
os
import
warnings
import
torch
import
torch
from
diffusers.loaders
import
FluxLoraLoaderMixin
from
diffusers.loaders
import
FluxLoraLoaderMixin
...
@@ -9,6 +9,52 @@ from safetensors.torch import save_file
...
@@ -9,6 +9,52 @@ from safetensors.torch import save_file
from
.utils
import
load_state_dict_in_safetensors
from
.utils
import
load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()

# Configure logging
# NOTE(review): calling logging.basicConfig at import time configures the
# *root* logger as a side effect of importing this library module; confirm
# that is intended (applications usually prefer to own logging config).
logging.basicConfig(
    # getattr falls back to INFO when LOG_LEVEL names an unknown level
    level=getattr(logging, log_level, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
# Ordered (old, new) substring rewrites mapping kohya layer names to
# diffusers-style dotted paths. ORDER MATTERS: "single_transformer_blocks_"
# must be rewritten before its substring "transformer_blocks_", and "_attn_"
# before the block-index underscores are converted to dots.
_KOHYA_REPLACEMENTS: tuple[tuple[str, str], ...] = (
    ("lora_transformer_", "transformer."),
    ("norm_out_", "norm_out."),
    ("time_text_embed_", "time_text_embed."),
    ("guidance_embedder_", "guidance_embedder."),
    ("text_embedder_", "text_embedder."),
    ("timestep_embedder_", "timestep_embedder."),
    ("single_transformer_blocks_", "single_transformer_blocks."),
    ("_attn_", ".attn."),
    ("_norm_linear.", ".norm.linear."),
    ("_proj_mlp.", ".proj_mlp."),
    ("_proj_out.", ".proj_out."),
    ("transformer_blocks_", "transformer_blocks."),
    ("to_out_0.", "to_out.0."),
    ("_ff_context_net_0_proj.", ".ff_context.net.0.proj."),
    ("_ff_context_net_2.", ".ff_context.net.2."),
    ("_ff_net_0_proj.", ".ff.net.0.proj."),
    ("_ff_net_2.", ".ff.net.2."),
    ("_norm1_context_linear.", ".norm1_context.linear."),
    ("_norm1_linear.", ".norm1.linear."),
    (".lora_down.", ".lora_A."),
    (".lora_up.", ".lora_B."),
)


def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename kohya-format LoRA keys to diffusers-style dotted keys.

    A state dict is treated as kohya-format only when *every* key starts with
    ``"lora_transformer_"`` (example: https://civitai.com/models/1118358?modelVersionId=1256866).
    Non-kohya state dicts are returned unchanged (same object). Tensor values
    are never modified, only the keys.
    """
    # first check if the state_dict is in the kohya format
    if not all(k.startswith("lora_transformer_") for k in state_dict):
        return state_dict
    new_state_dict: dict[str, torch.Tensor] = {}
    for key, tensor in state_dict.items():
        for old, new in _KOHYA_REPLACEMENTS:
            key = key.replace(old, new)
        new_state_dict[key] = tensor
    return new_state_dict
def
to_diffusers
(
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
)
->
dict
[
str
,
torch
.
Tensor
]:
def
to_diffusers
(
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
)
->
dict
[
str
,
torch
.
Tensor
]:
if
isinstance
(
input_lora
,
str
):
if
isinstance
(
input_lora
,
str
):
...
@@ -16,6 +62,8 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
...
@@ -16,6 +62,8 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
else
:
else
:
tensors
=
{
k
:
v
for
k
,
v
in
input_lora
.
items
()}
tensors
=
{
k
:
v
for
k
,
v
in
input_lora
.
items
()}
tensors
=
handle_kohya_lora
(
tensors
)
### convert the FP8 tensors to BF16
### convert the FP8 tensors to BF16
for
k
,
v
in
tensors
.
items
():
for
k
,
v
in
tensors
.
items
():
if
v
.
dtype
not
in
[
torch
.
float64
,
torch
.
float32
,
torch
.
bfloat16
,
torch
.
float16
]:
if
v
.
dtype
not
in
[
torch
.
float64
,
torch
.
float32
,
torch
.
bfloat16
,
torch
.
float16
]:
...
@@ -25,7 +73,14 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
...
@@ -25,7 +73,14 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
new_tensors
=
convert_unet_state_dict_to_peft
(
new_tensors
)
new_tensors
=
convert_unet_state_dict_to_peft
(
new_tensors
)
if
alphas
is
not
None
and
len
(
alphas
)
>
0
:
if
alphas
is
not
None
and
len
(
alphas
)
>
0
:
warnings
.
warn
(
"Alpha values are not used in the conversion to diffusers format."
)
for
k
,
v
in
alphas
.
items
():
key_A
=
k
.
replace
(
".alpha"
,
".lora_A.weight"
)
key_B
=
k
.
replace
(
".alpha"
,
".lora_B.weight"
)
assert
key_A
in
new_tensors
,
f
"Key
{
key_A
}
not found in new tensors."
assert
key_B
in
new_tensors
,
f
"Key
{
key_B
}
not found in new tensors."
rank
=
new_tensors
[
key_A
].
shape
[
0
]
assert
new_tensors
[
key_B
].
shape
[
1
]
==
rank
,
f
"Rank mismatch for
{
key_B
}
."
new_tensors
[
key_A
]
=
new_tensors
[
key_A
]
*
v
/
rank
if
output_path
is
not
None
:
if
output_path
is
not
None
:
output_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
output_path
))
output_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
output_path
))
...
...
nunchaku/lora/flux/nunchaku_converter.py
View file @
8fbf418d
...
@@ -12,8 +12,14 @@ from .diffusers_converter import to_diffusers
...
@@ -12,8 +12,14 @@ from .diffusers_converter import to_diffusers
from
.packer
import
NunchakuWeightPacker
from
.packer
import
NunchakuWeightPacker
from
.utils
import
is_nunchaku_format
,
pad
from
.utils
import
is_nunchaku_format
,
pad
# Get log level from environment variable (default to INFO)
log_level
=
os
.
getenv
(
"LOG_LEVEL"
,
"INFO"
).
upper
()
# Configure logging
logging
.
basicConfig
(
level
=
getattr
(
logging
,
log_level
,
logging
.
INFO
),
format
=
"%(asctime)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# region utilities
# region utilities
...
...
tests/flux/test_flux_dev_loras.py
View file @
8fbf418d
import
pytest
import
pytest
import
torch
from
diffusers
import
FluxPipeline
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
nunchaku.utils
import
get_precision
,
is_turing
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
from
.utils
import
run_test
...
@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
...
@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
cache_threshold
=
0
,
cache_threshold
=
0
,
expected_lpips
=
0.310
if
get_precision
()
==
"int4"
else
0.168
,
expected_lpips
=
0.310
if
get_precision
()
==
"int4"
else
0.168
,
)
)
def test_kohya_lora():
    """Smoke-test loading a kohya-format LoRA into the quantized FLUX.1-dev pipeline.

    Requires a CUDA GPU and network access to download weights. The generated
    image is saved to disk; there is no quality assertion.
    """
    precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
    # Load the SVDQuant transformer matching the detected precision.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    # Apply a kohya-format LoRA at full strength — presumably this exercises
    # the kohya key conversion in the LoRA converter (TODO confirm).
    transformer.update_lora_params("mit-han-lab/nunchaku-test-models/hand_drawn_game.safetensors")
    transformer.set_lora_strength(1)
    prompt = (
        "masterful impressionism oil painting titled 'the violinist', the composition follows the rule of thirds, "
        "placing the violinist centrally in the frame. the subject is a young woman with fair skin and light blonde "
        "hair is styled in a long, flowing hairstyle with natural waves. she is dressed in an opulent, "
        "luxurious silver silk gown with a high waist and intricate gold detailing along the bodice. "
        "the gown's texture is smooth and reflective. she holds a violin under her chin, "
        "her right hand poised to play, and her left hand supporting the neck of the instrument. "
        "she wears a delicate gold necklace with small, sparkling gemstones that catch the light. "
        "her beautiful eyes focused on the viewer. the background features an elegantly furnished room "
        "with classical late 19th century decor. to the left, there is a large, ornate portrait of "
        "a man in a dark suit, set in a gilded frame. below this, a wooden desk with a closed book. "
        "to the right, a red upholstered chair with a wooden frame is partially visible. "
        "the room is bathed in natural light streaming through a window with red curtains, "
        "creating a warm, inviting atmosphere. the lighting highlights the violinist, "
        "casting soft shadows that enhance the depth and realism of the scene, highly aesthetic, "
        "harmonious colors, impressioniststrokes, "
        "<lora:style-impressionist_strokes-flux-by_daalis:1.0> <lora:image_upgrade-flux-by_zeronwo7829:1.0>"
    )
    image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
    # NOTE(review): unlike the other tests in this file, there is no LPIPS /
    # reference-image assertion — the test only checks generation completes.
    image.save(f"flux.1-dev-{precision}-1.png")
tests/flux/test_flux_dev_pulid.py
View file @
8fbf418d
import
gc
from
types
import
MethodType
from
types
import
MethodType
import
numpy
as
np
import
numpy
as
np
...
@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
...
@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to using Turing GPUs"
)
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to using Turing GPUs"
)
def
test_flux_dev_pulid
():
def
test_flux_dev_pulid
():
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
precision
=
get_precision
()
# auto-detect your precision is 'int4' or 'fp4' based on your GPU
precision
=
get_precision
()
# auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
f
"mit-han-lab/nunchaku-flux.1-dev/svdq-
{
precision
}
_r32-flux.1-dev.safetensors"
f
"mit-han-lab/nunchaku-flux.1-dev/svdq-
{
precision
}
_r32-flux.1-dev.safetensors"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment