fengzch-das / nunchaku · Commits

Commit 2ede5f01 — "Clean some codes and refract the tests"
Authored Apr 03, 2025 by Muyang Li; committed Apr 04, 2025 by Zhekai Zhang
Parent: 83b7542d

Showing 20 changed files with 742 additions and 547 deletions.
examples/sana_1600m-cache.py                         +1    −2
examples/sana_1600m.py                               +1    −1
examples/sana_1600m_pag.py                           +1    −1
nunchaku/caching/diffusers_adapters/flux.py          +3    −3
nunchaku/caching/utils.py                            +16   −16
nunchaku/models/transformers/transformer_flux.py     +12   −6
nunchaku/models/transformers/transformer_sana.py     +2    −1
nunchaku/models/transformers/utils.py                +0    −19
nunchaku/utils.py                                    +36   −0
tests/data/MJHQ/MJHQ.py                              +82   −33
tests/data/__init__.py                               +12   −2
tests/flux/test_flux_cache.py                        +18   −29
tests/flux/test_flux_dev.py                          +20   −159
tests/flux/test_flux_dev_loras.py                    +89   −23
tests/flux/test_flux_memory.py                       +8    −2
tests/flux/test_flux_schnell.py                      +17   −88
tests/flux/test_flux_tools.py                        +122  −70
tests/flux/test_shuttle_jaguar.py                    +14   −88
tests/flux/test_turing.py                            +30   −0
tests/flux/utils.py                                  +258  −4
examples/int4-sana_1600m-cache.py → examples/sana_1600m-cache.py  (+1 −2)

@@ -4,7 +4,6 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

 transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
@@ -29,4 +28,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m.png")
+image.save("sana_1600m-int4.png")
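This rename (and the two that follow) moves the "int4" tag from the script's filename into the saved image name. For orientation, a minimal runnable version of the cache example, stitched together from the context lines above — the transformer=/torch_dtype= pipeline kwargs, the .to("cuda") call, the threshold value, and the prompt are assumptions, not visible in this diff:

    import torch
    from diffusers import SanaPipeline

    from nunchaku import NunchakuSanaTransformer2DModel
    from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

    # Load the INT4-quantized SANA transformer and drop it into the BF16 pipeline.
    transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        transformer=transformer,  # assumption: how the example wires the quantized model in
        torch_dtype=torch.bfloat16,  # assumption
    ).to("cuda")

    # Enable first-block residual caching; kwarg name matches the old flux test's
    # apply_cache_on_pipe(pipeline, residual_diff_threshold=...) call, value is illustrative.
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)

    image = pipe(
        "A cute panda eating bamboo",  # hypothetical prompt; not visible in the diff
        generator=torch.Generator().manual_seed(42),
    ).images[0]
    image.save("sana_1600m-int4.png")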
examples/int4-sana_1600m.py → examples/sana_1600m.py  (+1 −1)

@@ -23,4 +23,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m.png")
+image.save("sana_1600m-int4.png")
examples/int4-sana_1600m_pag.py → examples/sana_1600m_pag.py  (+1 −1)

@@ -24,4 +24,4 @@ image = pipe(
     pag_scale=2.0,
     num_inference_steps=20,
 ).images[0]
-image.save("sana_1600m_pag.png")
+image.save("sana_1600m_pag-int4.png")
nunchaku/caching/diffusers_adapters/flux.py  (+3 −3)

 import functools
 import unittest

-import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel
+from torch import nn

 from ...caching import utils

@@ -11,7 +11,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
     if getattr(transformer, "_is_cached", False):
         return transformer

-    cached_transformer_blocks = torch.nn.ModuleList(
+    cached_transformer_blocks = nn.ModuleList(
         [
             utils.FluxCachedTransformerBlocks(
                 transformer=transformer,
@@ -20,7 +20,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
             )
         ]
     )
-    dummy_single_transformer_blocks = torch.nn.ModuleList()
+    dummy_single_transformer_blocks = nn.ModuleList()

     original_forward = transformer.forward
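The _is_cached guard visible in the first hunk makes the adapter idempotent: patching an already-patched transformer is a no-op. A self-contained sketch of that wrap-once pattern (the cached-block internals are elided; only the guard and the flag are taken from the hunk, and setting the flag at the end is an assumption):

    from torch import nn

    def apply_cache_on_transformer(transformer: nn.Module, *, residual_diff_threshold: float = 0.12) -> nn.Module:
        # Idempotency guard from the hunk above: a second call returns immediately.
        if getattr(transformer, "_is_cached", False):
            return transformer

        # ... here the real adapter swaps transformer.transformer_blocks for
        # utils.FluxCachedTransformerBlocks and wraps transformer.forward ...

        transformer._is_cached = True  # assumption: flag set once patching succeeds
        return transformer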
nunchaku/caching/utils.py  (+16 −16)

@@ -94,7 +94,6 @@ def apply_prev_hidden_states_residual(
     encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
     encoder_hidden_states = encoder_hidden_states.contiguous()

     return hidden_states, encoder_hidden_states
@@ -109,6 +108,7 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=Fals
     )
     return can_use_cache


 class SanaCachedTransformerBlocks(nn.Module):
     def __init__(
         self,
@@ -123,7 +123,8 @@ class SanaCachedTransformerBlocks(nn.Module):
         self.residual_diff_threshold = residual_diff_threshold
         self.verbose = verbose

     def forward(
         self,
         hidden_states,
         attention_mask,
         encoder_hidden_states,
@@ -135,8 +136,7 @@ class SanaCachedTransformerBlocks(nn.Module):
         batch_size = hidden_states.shape[0]
         if self.residual_diff_threshold <= 0.0 or batch_size > 2:
             if batch_size > 2:
                 print("Batch size > 2 (for SANA CFG)" " currently not supported")
             first_transformer_block = self.transformer_blocks[0]
             hidden_states = first_transformer_block(
@@ -199,15 +199,15 @@ class SanaCachedTransformerBlocks(nn.Module):
         return hidden_states

     def call_remaining_transformer_blocks(
         self,
         hidden_states,
         attention_mask,
         encoder_hidden_states,
         encoder_attention_mask=None,
         timestep=None,
         post_patch_height=None,
-        post_patch_width=None
+        post_patch_width=None,
     ):
         first_transformer_block = self.transformer_blocks[0]
         original_hidden_states = hidden_states
@@ -219,7 +219,7 @@ class SanaCachedTransformerBlocks(nn.Module):
             timestep=timestep,
             height=post_patch_height,
             width=post_patch_width,
-            skip_first_layer=True
+            skip_first_layer=True,
         )
         hidden_states_residual = hidden_states - original_hidden_states
nunchaku/models/transformers/transformer_flux.py  (+12 −6)

@@ -13,11 +13,11 @@ from packaging.version import Version
 from safetensors.torch import load_file, save_file
 from torch import nn

-from .utils import get_precision, NunchakuModelLoaderMixin, pad_tensor
+from .utils import NunchakuModelLoaderMixin, pad_tensor
 from ..._C import QuantizedFluxModel, utils as cutils
 from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
 from ...lora.flux.utils import is_nunchaku_format
-from ...utils import load_state_dict_in_safetensors
+from ...utils import get_precision, load_state_dict_in_safetensors

 SVD_RANK = 32
@@ -127,7 +127,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         image_rotary_emb: torch.Tensor,
         joint_attention_kwargs=None,
         controlnet_block_samples=None,
-        controlnet_single_block_samples=None
+        controlnet_single_block_samples=None,
     ):
         batch_size = hidden_states.shape[0]
         txt_tokens = encoder_hidden_states.shape[1]
@@ -159,8 +159,14 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))

         hidden_states, encoder_hidden_states = self.m.forward_layer(
-            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
-            controlnet_block_samples, controlnet_single_block_samples
+            idx,
+            hidden_states,
+            encoder_hidden_states,
+            temb,
+            rotary_emb_img,
+            rotary_emb_txt,
+            controlnet_block_samples,
+            controlnet_single_block_samples,
         )
         hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -578,7 +584,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
             image_rotary_emb=image_rotary_emb,
             joint_attention_kwargs=joint_attention_kwargs,
             controlnet_block_samples=controlnet_block_samples,
-            controlnet_single_block_samples=controlnet_single_block_samples
+            controlnet_single_block_samples=controlnet_single_block_samples,
         )
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
nunchaku/models/transformers/transformer_sana.py  (+2 −1)

@@ -8,7 +8,8 @@ from safetensors.torch import load_file
 from torch import nn
 from torch.nn import functional as F

-from .utils import get_precision, NunchakuModelLoaderMixin
+from .utils import NunchakuModelLoaderMixin
+from ...utils import get_precision
 from ..._C import QuantizedSanaModel, utils as cutils

 SVD_RANK = 32
nunchaku/models/transformers/utils.py  (+0 −19)

 import os
-import warnings
 from typing import Any, Optional

 import torch

@@ -82,21 +81,3 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: A
     result.fill_(fill)
     result[[slice(0, extent) for extent in tensor.shape]] = tensor
     return result
-
-
-def get_precision(precision: str, device: str | torch.device, pretrained_model_name_or_path: str | None = None) -> str:
-    assert precision in ("auto", "int4", "fp4")
-    if precision == "auto":
-        if isinstance(device, str):
-            device = torch.device(device)
-        capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
-        sm = f"{capability[0]}{capability[1]}"
-        precision = "fp4" if sm == "120" else "int4"
-    if pretrained_model_name_or_path is not None:
-        if precision == "int4":
-            if "fp4" in pretrained_model_name_or_path:
-                warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
-        elif precision == "fp4":
-            if "int4" in pretrained_model_name_or_path:
-                warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
-    return precision
nunchaku/utils.py  (+36 −0)

 import os
+import warnings

 import safetensors
 import torch

@@ -69,3 +70,38 @@ def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str =
         filtered state dict.
     """
     return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
+
+
+def get_precision(
+    precision: str = "auto", device: str | torch.device = "cuda", pretrained_model_name_or_path: str | None = None
+) -> str:
+    assert precision in ("auto", "int4", "fp4")
+    if precision == "auto":
+        if isinstance(device, str):
+            device = torch.device(device)
+        capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+        sm = f"{capability[0]}{capability[1]}"
+        precision = "fp4" if sm == "120" else "int4"
+    if pretrained_model_name_or_path is not None:
+        if precision == "int4":
+            if "fp4" in pretrained_model_name_or_path:
+                warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
+        elif precision == "fp4":
+            if "int4" in pretrained_model_name_or_path:
+                warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
+    return precision
+
+
+def is_turing(device: str | torch.device = "cuda") -> bool:
+    """Check if the current GPU is a Turing GPU.
+
+    Returns:
+        `bool`:
+            True if the current GPU is a Turing GPU, False otherwise.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    device_id = 0 if device.index is None else device.index
+    capability = torch.cuda.get_device_capability(device_id)
+    sm = f"{capability[0]}{capability[1]}"
+    return sm == "75"
tests/data/MJHQ/MJHQ.py  (+82 −33)

@@ -3,6 +3,7 @@ import os
 import random

 import datasets
+import yaml
 from PIL import Image

 _CITATION = """\
@@ -32,6 +33,8 @@ IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/
 META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
+CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"


 class MJHQConfig(datasets.BuilderConfig):
     def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
@@ -46,11 +49,14 @@ class MJHQConfig(datasets.BuilderConfig):
         self.return_gt = return_gt


-class DCI(datasets.GeneratorBasedBuilder):
+class MJHQ(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.0")

     BUILDER_CONFIG_CLASS = MJHQConfig
-    BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
+    BUILDER_CONFIGS = [
+        MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
+        MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
+    ]
     DEFAULT_CONFIG_NAME = "MJHQ"

     def _info(self):
@@ -64,6 +70,10 @@ class MJHQ(datasets.GeneratorBasedBuilder):
                 "image_root": datasets.Value("string"),
                 "image_path": datasets.Value("string"),
                 "split": datasets.Value("string"),
+                "canny_image_path": datasets.Value("string"),
+                "cropped_image_path": datasets.Value("string"),
+                "depth_image_path": datasets.Value("string"),
+                "mask_image_path": datasets.Value("string"),
             }
         )
         return datasets.DatasetInfo(
@@ -71,6 +81,7 @@ class MJHQ(datasets.GeneratorBasedBuilder):
         )

     def _split_generators(self, dl_manager: datasets.download.DownloadManager):
+        if self.config.name == "MJHQ":
             meta_path = dl_manager.download(META_URL)
             image_root = dl_manager.download_and_extract(IMAGE_URL)
             return [
@@ -78,9 +89,19 @@ class MJHQ(datasets.GeneratorBasedBuilder):
                     name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
                 ),
             ]
+        else:
+            assert self.config.name == "MJHQ-control"
+            control_root = dl_manager.download_and_extract(CONTROL_URL)
+            control_root = os.path.join(control_root, "MJHQ-5000")
+            return [
+                datasets.SplitGenerator(
+                    name=datasets.Split.TRAIN,
+                    gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
+                ),
+            ]

     def _generate_examples(self, meta_path: str, image_root: str):
+        if self.config.name == "MJHQ":
             with open(meta_path, "r") as f:
                 meta = json.load(f)
@@ -103,4 +124,32 @@ class MJHQ(datasets.GeneratorBasedBuilder):
                 "image_root": image_root,
                 "image_path": image_path,
                 "split": self.config.name,
+                "canny_image_path": None,
+                "cropped_image_path": None,
+                "depth_image_path": None,
+                "mask_image_path": None,
             }
+        else:
+            assert self.config.name == "MJHQ-control"
+            meta = yaml.safe_load(open(meta_path, "r"))
+            names = list(meta.keys())
+            if self.config.max_dataset_size > 0:
+                random.Random(0).shuffle(names)
+                names = names[: self.config.max_dataset_size]
+                names = sorted(names)
+            for i, name in enumerate(names):
+                prompt = meta[name]
+                yield i, {
+                    "filename": name,
+                    "category": None,
+                    "image": None,
+                    "prompt": prompt,
+                    "meta_path": meta_path,
+                    "image_root": image_root,
+                    "image_path": os.path.join(image_root, "images", f"{name}.png"),
+                    "split": self.config.name,
+                    "canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
+                    "cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
+                    "depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
+                    "mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
+                }
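A quick way to exercise the new config; a sketch only, and the relative script path is an assumption that mirrors how tests/data/__init__.py resolves the "MJHQ" directory:

    import datasets

    # Load the new MJHQ-control builder config; each row carries paths to the
    # canny/cropped/depth/mask control images alongside the prompt.
    dataset = datasets.load_dataset("tests/data/MJHQ", name="MJHQ-control", split="train")
    row = dataset[0]
    print(row["prompt"], row["canny_image_path"], row["mask_image_path"])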
tests/data/__init__.py  (+12 −2)

@@ -3,9 +3,16 @@ import random
 import datasets
 import yaml
+from huggingface_hub import snapshot_download

 from nunchaku.utils import fetch_or_download

-__all__ = ["get_dataset"]
+__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]
+
+
+def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
+    path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
+    return path


 def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
@@ -46,10 +53,13 @@ def get_dataset(
     path = os.path.join(prefix, f"{name}")
     if name == "MJHQ":
         dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
+    elif name == "MJHQ-control":
+        kwargs["name"] = "MJHQ-control"
+        dataset = datasets.load_dataset(os.path.join(prefix, "MJHQ"), return_gt=return_gt, **kwargs)
     else:
         dataset = datasets.Dataset.from_dict(
             load_dataset_yaml(
-                fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset"),
+                fetch_or_download(f"mit-han-lab/svdquant-datasets/{name}.yaml", repo_type="dataset"),
                 max_dataset_size=max_dataset_size,
                 repeat=1,
             ),
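In the tests, all of this is reached through get_dataset; e.g. (sizes are illustrative):

    from tests.data import get_dataset

    # Plain MJHQ prompts, capped for fast LPIPS comparisons.
    mjhq = get_dataset(name="MJHQ", max_dataset_size=16)

    # The new control variant, routed to the "MJHQ-control" builder config above.
    control = get_dataset(name="MJHQ-control", max_dataset_size=16)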
tests/flux/test_flux_cache.py  (+18 −29)

Old:

    import pytest

    from .test_flux_dev import run_test_flux_dev


    @pytest.mark.parametrize(
        "height,width,num_inference_steps,cache_threshold,lora_name,use_qencoder,cpu_offload,expected_lpips",
        [
            # (1024, 1024, 50, 0, None, False, False, 0.5),  # 13min20s 5min55s 0.19539418816566467
            # (1024, 1024, 50, 0.05, None, False, True, 0.5),  # 7min11s 0.21917256712913513
            # (1024, 1024, 50, 0.12, None, False, True, 0.5),  # 2min58s, 0.24101486802101135
            # (1024, 1024, 50, 0.2, None, False, True, 0.5),  # 2min23s, 0.3101634383201599
            # (1024, 1024, 50, 0.5, None, False, True, 0.5),  # 1min44s 0.6543852090835571
            # (1024, 1024, 30, 0, None, False, False, 0.5),  # 8min2s 3min40s 0.2141970843076706
            # (1024, 1024, 30, 0.05, None, False, True, 0.5),  # 4min57 0.21297718584537506
            # (1024, 1024, 30, 0.12, None, False, True, 0.5),  # 2min34 0.25963714718818665
            # (1024, 1024, 30, 0.2, None, False, True, 0.5),  # 1min51 0.31409069895744324
            # (1024, 1024, 20, 0, None, False, False, 0.5),  # 5min25 2min29 0.18987375497817993
            # (1024, 1024, 20, 0.05, None, False, True, 0.5),  # 3min3 0.17194810509681702
            # (1024, 1024, 20, 0.12, None, False, True, 0.5),  # 2min15 0.19407868385314941
            # (1024, 1024, 20, 0.2, None, False, True, 0.5),  # 1min48 0.2832985818386078
            (1024, 1024, 30, 0.12, None, False, False, 0.26),
            (512, 2048, 30, 0.12, "anime", True, False, 0.4),
        ],
    )
    def test_flux_dev_base(
        height: int,
        width: int,
        num_inference_steps: int,
        cache_threshold: float,
        lora_name: str,
        use_qencoder: bool,
        cpu_offload: bool,
        expected_lpips: float,
    ):
        run_test_flux_dev(
            precision="int4",
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            use_qencoder=use_qencoder,
            cpu_offload=cpu_offload,
            lora_name=lora_name,
            lora_scale=1,
            cache_threshold=cache_threshold,
            max_dataset_size=16,
            expected_lpips=expected_lpips,
        )

New:

    import pytest

    from nunchaku.utils import get_precision, is_turing
    from .utils import run_test


    @pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
    @pytest.mark.parametrize(
        "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
        [
            (0.12, 1024, 1024, 30, None, 1, 0.26),
            (0.12, 512, 2048, 30, "anime", 1, 0.4),
        ],
    )
    def test_flux_dev_loras(
        cache_threshold: float,
        height: int,
        width: int,
        num_inference_steps: int,
        lora_name: str | None,
        lora_strength: float,
        expected_lpips: float,
    ):
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ" if lora_name is None else lora_name,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=False,
            lora_names=lora_name,
            lora_strengths=lora_strength,
            cache_threshold=cache_threshold,
            expected_lpips=expected_lpips,
        )
tests/flux/test_flux_dev.py  (+20 −159)

Old:

    import os

    import pytest
    import torch
    from diffusers import FluxPipeline
    from peft.tuners import lora

    from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
    from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
    from nunchaku.lora.flux import convert_to_nunchaku_flux_lowrank_dict, is_nunchaku_format, to_diffusers
    from .utils import run_pipeline
    from ..data import get_dataset
    from ..utils import already_generate, compute_lpips

    LORA_PATH_MAP = {
        "hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
        "realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
        "ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
        "anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
        "sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
        "yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
        "haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
    }


    def run_test_flux_dev(
        precision: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        use_qencoder: bool,
        cpu_offload: bool,
        lora_name: str | None,
        lora_scale: float,
        cache_threshold: float,
        max_dataset_size: int,
        expected_lpips: float,
    ):
        save_root = os.path.join(
            "results",
            "dev",
            f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
            + (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
        )
        dataset = get_dataset(
            name="MJHQ" if lora_name in [None, "hypersd8"] else lora_name, max_dataset_size=max_dataset_size
        )
        save_dir_16bit = os.path.join(save_root, "bf16")
        if not already_generate(save_dir_16bit, max_dataset_size):
            pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
            pipeline = pipeline.to("cuda")
            if lora_name is not None:
                pipeline.load_lora_weights(
                    os.path.dirname(LORA_PATH_MAP[lora_name]),
                    weight_name=os.path.basename(LORA_PATH_MAP[lora_name]),
                    adapter_name="lora",
                )
                for n, m in pipeline.transformer.named_modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_scale
            run_pipeline(
                dataset,
                pipeline,
                save_dir=save_dir_16bit,
                forward_kwargs={
                    "height": height,
                    "width": width,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                },
            )
            del pipeline
            # release the gpu memory
            torch.cuda.empty_cache()

        name = precision
        name += "-qencoder" if use_qencoder else ""
        name += "-offload" if cpu_offload else ""
        name += f"-cache{cache_threshold:.2f}" if cache_threshold > 0 else ""
        save_dir_4bit = os.path.join(save_root, name)
        if not already_generate(save_dir_4bit, max_dataset_size):
            pipeline_init_kwargs = {}
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/svdq-int4-flux.1-dev", offload=cpu_offload
                )
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/svdq-fp4-flux.1-dev", precision="fp4", offload=cpu_offload
                )
            if lora_name is not None:
                lora_path = LORA_PATH_MAP[lora_name]
                transformer.update_lora_params(lora_path)
                transformer.set_lora_strength(lora_scale)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            pipeline = pipeline.to("cuda")
            if cpu_offload:
                pipeline.enable_sequential_cpu_offload()
            if cache_threshold > 0:
                apply_cache_on_pipe(pipeline, residual_diff_threshold=cache_threshold)
            run_pipeline(
                dataset,
                pipeline,
                save_dir=save_dir_4bit,
                forward_kwargs={
                    "height": height,
                    "width": width,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                },
            )
            del pipeline
            # release the gpu memory
            torch.cuda.empty_cache()
        lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
        print(f"lpips: {lpips}")
        assert lpips < expected_lpips * 1.05


    @pytest.mark.parametrize("cpu_offload", [False, True])
    def test_flux_dev_base(cpu_offload: bool):
        run_test_flux_dev(
            precision="int4",
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=cpu_offload,
            lora_name=None,
            lora_scale=0,
            cache_threshold=0,
            max_dataset_size=8,
            expected_lpips=0.16,
        )


    def test_flux_dev_qencoder_800x600():
        run_test_flux_dev(
            precision="int4",
            height=800,
            width=600,
            num_inference_steps=50,
            guidance_scale=3.5,
            use_qencoder=True,
            cpu_offload=False,
            lora_name=None,
            lora_scale=0,
            cache_threshold=0,
            max_dataset_size=8,
            expected_lpips=0.36,
        )

New:

    import pytest

    from nunchaku.utils import get_precision, is_turing
    from .utils import run_test


    @pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
    @pytest.mark.parametrize(
        "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
        [
            (1024, 1024, 50, "flashattn2", False, 0.226),
            (2048, 512, 25, "nunchaku-fp16", False, 0.243),
        ],
    )
    def test_flux_dev(
        height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
    ):
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            attention_impl=attention_impl,
            cpu_offload=cpu_offload,
            expected_lpips=expected_lpips,
        )
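The acceptance criterion from the deleted helper (and, judging by the call sites, kept inside run_test) compares the 4-bit outputs against bf16 references with a 5% LPIPS tolerance. Distilled as a standalone helper — the function name is hypothetical, the body is taken from the deleted code:

    from tests.utils import compute_lpips  # same import the deleted code used


    def assert_lpips_close(save_dir_16bit: str, save_dir_4bit: str, expected_lpips: float) -> None:
        """LPIPS regression gate distilled from the deleted test body."""
        lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
        print(f"lpips: {lpips}")
        assert lpips < expected_lpips * 1.05  # allow 5% drift over the recorded reference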
tests/flux/test_flux_dev_loras.py  (+89 −23)

Old:

    import pytest

    from tests.flux.test_flux_dev import run_test_flux_dev


    @pytest.mark.parametrize(
        "num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
        [
            (25, "realism", 0.9, False, 0.17),
            (25, "ghibsky", 1, False, 0.16),
            (28, "anime", 1, False, 0.27),
            (24, "sketch", 1, False, 0.35),
            (28, "yarn", 1, False, 0.22),
            (25, "haunted_linework", 1, False, 0.34),
        ],
    )
    def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
        run_test_flux_dev(
            precision="int4",
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=cpu_offload,
            lora_name=lora_name,
            lora_scale=lora_scale,
            cache_threshold=0,
            max_dataset_size=8,
            expected_lpips=expected_lpips,
        )


    def test_flux_dev_hypersd8_1080x1920():
        run_test_flux_dev(
            precision="int4",
            height=1080,
            width=1920,
            num_inference_steps=8,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=False,
            lora_name="hypersd8",
            lora_scale=0.125,
            cache_threshold=0,
            max_dataset_size=8,
            expected_lpips=0.44,
        )

New:

    import pytest

    from nunchaku.utils import get_precision, is_turing
    from .utils import run_test


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    @pytest.mark.parametrize(
        "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
        [
            (25, "realism", 0.9, True, 0.178),
            (25, "ghibsky", 1, False, 0.164),
            (28, "anime", 1, False, 0.284),
            (24, "sketch", 1, True, 0.223),
            (28, "yarn", 1, False, 0.211),
            (25, "haunted_linework", 1, True, 0.317),
        ],
    )
    def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name=lora_name,
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=cpu_offload,
            lora_names=lora_name,
            lora_strengths=lora_strength,
            cache_threshold=0,
            expected_lpips=expected_lpips,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_hypersd8_1536x2048():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ",
            height=1536,
            width=2048,
            num_inference_steps=8,
            guidance_scale=3.5,
            use_qencoder=False,
            attention_impl="nunchaku-fp16",
            cpu_offload=True,
            lora_names="hypersd8",
            lora_strengths=0.125,
            cache_threshold=0,
            expected_lpips=0.291,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_turbo8_2048x2048():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ",
            height=2048,
            width=2048,
            num_inference_steps=8,
            guidance_scale=3.5,
            use_qencoder=False,
            attention_impl="nunchaku-fp16",
            cpu_offload=True,
            lora_names="turbo8",
            lora_strengths=1,
            cache_threshold=0,
            expected_lpips=0.189,
        )


    # lora composition
    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_turbo8_yarn_2048x1024():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="yarn",
            height=2048,
            width=1024,
            num_inference_steps=8,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=True,
            lora_names=["turbo8", "yarn"],
            lora_strengths=[1, 1],
            cache_threshold=0,
            expected_lpips=0.252,
        )


    # large rank loras
    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_turbo8_yarn_1024x1024():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="ghibsky",
            height=1024,
            width=1024,
            num_inference_steps=8,
            guidance_scale=3.5,
            use_qencoder=False,
            cpu_offload=True,
            lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
            lora_strengths=[0, 1, 0, 0, 0, 0, 1],
            cache_threshold=0,
            expected_lpips=0.44,
        )
tests/flux/test_flux_memory.py  (+8 −2)

@@ -3,8 +3,10 @@ import torch
 from diffusers import FluxPipeline

 from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+from nunchaku.utils import get_precision, is_turing


+@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
 @pytest.mark.parametrize(
     "use_qencoder,cpu_offload,memory_limit",
     [
@@ -15,10 +17,12 @@
     ],
 )
 def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
+    precision = get_precision()
     pipeline_init_kwargs = {
         "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
-            "mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
+            f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
         )
     }
     if use_qencoder:
@@ -26,10 +30,12 @@
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
-    pipeline = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
-    ).to("cuda")
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+    )
     if cpu_offload:
         pipeline.enable_sequential_cpu_offload()
+    else:
+        pipeline = pipeline.to("cuda")
     pipeline(
         "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
tests/flux/test_flux_schnell.py  (+17 −88)

Old:

    import os

    import pytest
    import torch
    from diffusers import FluxPipeline

    from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
    from tests.data import get_dataset
    from tests.flux.utils import run_pipeline
    from tests.utils import already_generate, compute_lpips


    @pytest.mark.parametrize(
        "precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
        [
            ("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
            ("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
            ("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
            ("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
            ("int4", 600, 800, 4, 0, False, False, 16, 0.29),
        ],
    )
    def test_flux_schnell(
        precision: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        use_qencoder: bool,
        cpu_offload: bool,
        max_dataset_size: int,
        expected_lpips: float,
    ):
        dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
        save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
        save_dir_16bit = os.path.join(save_root, "bf16")
        if not already_generate(save_dir_16bit, max_dataset_size):
            pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
            pipeline = pipeline.to("cuda")
            run_pipeline(
                dataset,
                pipeline,
                save_dir=save_dir_16bit,
                forward_kwargs={
                    "height": height,
                    "width": width,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                },
            )
            del pipeline
            # release the gpu memory
            torch.cuda.empty_cache()
        save_dir_4bit = os.path.join(
            save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
        )
        if not already_generate(save_dir_4bit, max_dataset_size):
            pipeline_init_kwargs = {}
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
                )
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
                )
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            pipeline = pipeline.to("cuda")
            if cpu_offload:
                pipeline.enable_sequential_cpu_offload()
            run_pipeline(
                dataset,
                pipeline,
                save_dir=save_dir_4bit,
                forward_kwargs={
                    "height": height,
                    "width": width,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                },
            )
            del pipeline
            # release the gpu memory
            torch.cuda.empty_cache()
        lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
        print(f"lpips: {lpips}")
        assert lpips < expected_lpips * 1.05

New:

    import pytest

    from nunchaku.utils import get_precision, is_turing
    from .utils import run_test


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    @pytest.mark.parametrize(
        "height,width,attention_impl,cpu_offload,expected_lpips",
        [
            (1024, 1024, "flashattn2", False, 0.250),
            (1024, 1024, "nunchaku-fp16", False, 0.255),
            (1024, 1024, "flashattn2", True, 0.250),
            (1920, 1080, "nunchaku-fp16", False, 0.253),
            (2048, 2048, "flashattn2", True, 0.274),
        ],
    )
    def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
        run_test(
            precision=get_precision(),
            height=height,
            width=width,
            attention_impl=attention_impl,
            cpu_offload=cpu_offload,
            expected_lpips=expected_lpips,
        )
tests/flux/test_flux_tools.py  (+122 −70)

Old:

    import torch
    from controlnet_aux import CannyDetector
    from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
    from diffusers.utils import load_image
    from image_gen_aux import DepthPreprocessor

    from nunchaku import NunchakuFluxTransformer2dModel


    def test_flux_dev_canny():
        transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
        pipe = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")
        prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."  # noqa: E501
        control_image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
        )
        processor = CannyDetector()
        control_image = processor(
            control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
        )
        image = pipe(
            prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=50, guidance_scale=30.0
        ).images[0]
        image.save("flux.1-canny-dev.png")


    def test_flux_dev_depth():
        transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
        pipe = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Depth-dev", transformer=transformer, torch_dtype=torch.bfloat16,
        ).to("cuda")
        prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."  # noqa: E501
        control_image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
        )
        processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
        control_image = processor(control_image)[0].convert("RGB")
        image = pipe(
            prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=30, guidance_scale=10.0
        ).images[0]
        image.save("flux.1-depth-dev.png")


    def test_flux_dev_fill():
        image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
        mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
        transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")
        image = pipe(
            prompt="A wooden basket of a cat.",
            image=image,
            mask_image=mask,
            height=1024,
            width=1024,
            guidance_scale=30,
            num_inference_steps=50,
            max_sequence_length=512,
        ).images[0]
        image.save("flux.1-fill-dev.png")


    def test_flux_dev_redux():
        pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
        ).to("cuda")
        transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            text_encoder=None,
            text_encoder_2=None,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
        pipe_prior_output = pipe_prior_redux(image)
        images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
        images[0].save("flux.1-redux-dev.png")

New:

    import pytest
    import torch

    from nunchaku.utils import get_precision, is_turing
    from .utils import run_test


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_canny_dev():
        run_test(
            precision=get_precision(),
            model_name="flux.1-canny-dev",
            dataset_name="MJHQ-control",
            task="canny",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=30,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            expected_lpips=0.103 if get_precision() == "int4" else 0.164,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_depth_dev():
        run_test(
            precision=get_precision(),
            model_name="flux.1-depth-dev",
            dataset_name="MJHQ-control",
            task="depth",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=30,
            guidance_scale=10,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            expected_lpips=0.103 if get_precision() == "int4" else 0.120,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_fill_dev():
        run_test(
            precision=get_precision(),
            model_name="flux.1-fill-dev",
            dataset_name="MJHQ-control",
            task="fill",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=30,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            expected_lpips=0.045,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_canny_lora():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ-control",
            task="canny",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=30,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            lora_names="canny",
            lora_strengths=0.85,
            cache_threshold=0,
            expected_lpips=0.103,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_depth_lora():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ-control",
            task="depth",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=30,
            guidance_scale=10,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            lora_names="depth",
            lora_strengths=0.85,
            expected_lpips=0.163,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_fill_dev_turbo():
        run_test(
            precision=get_precision(),
            model_name="flux.1-fill-dev",
            dataset_name="MJHQ-control",
            task="fill",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=8,
            guidance_scale=30,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            lora_names="turbo8",
            lora_strengths=1,
            expected_lpips=0.048,
        )


    @pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
    def test_flux_dev_redux():
        run_test(
            precision=get_precision(),
            model_name="flux.1-dev",
            dataset_name="MJHQ-control",
            task="redux",
            dtype=torch.bfloat16,
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=2.5,
            attention_impl="nunchaku-fp16",
            cpu_offload=False,
            cache_threshold=0,
            expected_lpips=0.187 if get_precision() == "int4" else 0.55,  # redux seems to generate different images on 5090
        )
View file @
2ede5f01
import
os
import
pytest
import
pytest
import
torch
from
diffusers
import
FluxPipeline
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
.utils
import
run_test
from
tests.data
import
get_dataset
from
nunchaku.utils
import
get_precision
,
is_turing
from
tests.flux.utils
import
run_pipeline
from
tests.utils
import
already_generate
,
compute_lpips
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to Turing GPUs"
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips"
,
"height,width,attention_impl,cpu_offload,expected_lpips"
,
[
[(
1024
,
1024
,
"flashattn2"
,
False
,
0.25
),
(
2048
,
512
,
"nunchaku-fp16"
,
False
,
0.25
)],
(
"int4"
,
1024
,
1024
,
4
,
3.5
,
False
,
False
,
16
,
0.25
),
(
"int4"
,
2048
,
512
,
4
,
3.5
,
False
,
False
,
16
,
0.21
),
],
)
)
def
test_shuttle_jaguar
(
def
test_shuttle_jaguar
(
height
:
int
,
width
:
int
,
attention_impl
:
str
,
cpu_offload
:
bool
,
expected_lpips
:
float
):
precision
:
str
,
run_test
(
height
:
int
,
precision
=
get_precision
(),
width
:
int
,
model_name
=
"shuttle-jaguar"
,
num_inference_steps
:
int
,
height
=
height
,
guidance_scale
:
float
,
width
=
width
,
use_qencoder
:
bool
,
attention_impl
=
attention_impl
,
cpu_offload
:
bool
,
cpu_offload
=
cpu_offload
,
max_dataset_size
:
int
,
expected_lpips
=
expected_lpips
,
expected_lpips
:
float
,
):
dataset
=
get_dataset
(
name
=
"MJHQ"
,
max_dataset_size
=
max_dataset_size
)
save_root
=
os
.
path
.
join
(
"results"
,
"shuttle-jaguar"
,
f
"w
{
width
}
h
{
height
}
t
{
num_inference_steps
}
g
{
guidance_scale
}
"
)
save_dir_16bit
=
os
.
path
.
join
(
save_root
,
"bf16"
)
if
not
already_generate
(
save_dir_16bit
,
max_dataset_size
):
pipeline
=
FluxPipeline
.
from_pretrained
(
"shuttleai/shuttle-jaguar"
,
torch_dtype
=
torch
.
bfloat16
)
pipeline
=
pipeline
.
to
(
"cuda"
)
run_pipeline
(
dataset
,
pipeline
,
save_dir
=
save_dir_16bit
,
forward_kwargs
=
{
"height"
:
height
,
"width"
:
width
,
"num_inference_steps"
:
num_inference_steps
,
"guidance_scale"
:
guidance_scale
,
},
)
del
pipeline
# release the gpu memory
torch
.
cuda
.
empty_cache
()
save_dir_4bit
=
os
.
path
.
join
(
save_root
,
f
"
{
precision
}
-qencoder"
if
use_qencoder
else
f
"
{
precision
}
"
+
(
"-cpuoffload"
if
cpu_offload
else
""
)
)
if
not
already_generate
(
save_dir_4bit
,
max_dataset_size
):
pipeline_init_kwargs
=
{}
if
precision
==
"int4"
:
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-shuttle-jaguar"
,
offload
=
cpu_offload
)
else
:
assert
precision
==
"fp4"
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
"mit-han-lab/svdq-fp4-shuttle-jaguar"
,
precision
=
"fp4"
,
offload
=
cpu_offload
)
pipeline_init_kwargs
[
"transformer"
]
=
transformer
if
use_qencoder
:
raise
NotImplementedError
# text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
# pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline
=
FluxPipeline
.
from_pretrained
(
"shuttleai/shuttle-jaguar"
,
torch_dtype
=
torch
.
bfloat16
,
**
pipeline_init_kwargs
)
pipeline
=
pipeline
.
to
(
"cuda"
)
if
cpu_offload
:
pipeline
.
enable_sequential_cpu_offload
()
run_pipeline
(
dataset
,
pipeline
,
save_dir
=
save_dir_4bit
,
forward_kwargs
=
{
"height"
:
height
,
"width"
:
width
,
"num_inference_steps"
:
num_inference_steps
,
"guidance_scale"
:
guidance_scale
,
},
)
)
del
pipeline
# release the gpu memory
torch
.
cuda
.
empty_cache
()
lpips
=
compute_lpips
(
save_dir_16bit
,
save_dir_4bit
)
print
(
f
"lpips:
{
lpips
}
"
)
assert
lpips
<
expected_lpips
*
1.05
tests/flux/test_turing.py  (new file, +30 −0)

    import pytest

    from nunchaku.utils import get_precision
    from .utils import run_test


    @pytest.mark.skipif(get_precision() == "fp4", reason="Blackwell GPUs. Skip tests for Turing.")
    @pytest.mark.parametrize(
        "height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
        [
            (1024, 1024, 50, True, None, 0.253),
            (1024, 1024, 50, True, "enabled", 0.258),
            (1024, 1024, 50, True, "always", 0.257),
        ],
    )
    def test_flux_dev(
        height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
    ):
        run_test(
            precision=get_precision(),
            dtype="fp16",
            model_name="flux.1-dev",
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            attention_impl="nunchaku-fp16",
            cpu_offload=cpu_offload,
            i2f_mode=i2f_mode,
            expected_lpips=expected_lpips,
        )
tests/flux/utils.py
View file @
2ede5f01
import
os
import
os
import
torch
import
torch
from
diffusers
import
FluxPipeline
from
controlnet_aux
import
CannyDetector
from
diffusers
import
FluxControlPipeline
,
FluxFillPipeline
,
FluxPipeline
,
FluxPriorReduxPipeline
from
diffusers.utils
import
load_image
from
image_gen_aux
import
DepthPreprocessor
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
..utils
import
hash_str_to_int
import
nunchaku
from
nunchaku
import
NunchakuFluxTransformer2dModel
,
NunchakuT5EncoderModel
from
nunchaku.lora.flux.compose
import
compose_lora
from
..data
import
download_hf_dataset
,
get_dataset
from
..utils
import
already_generate
,
compute_lpips
,
hash_str_to_int
ORIGINAL_REPO_MAP
=
{
"flux.1-schnell"
:
"black-forest-labs/FLUX.1-schnell"
,
"flux.1-dev"
:
"black-forest-labs/FLUX.1-dev"
,
"shuttle-jaguar"
:
"shuttleai/shuttle-jaguar"
,
"flux.1-canny-dev"
:
"black-forest-labs/FLUX.1-Canny-dev"
,
"flux.1-depth-dev"
:
"black-forest-labs/FLUX.1-Depth-dev"
,
"flux.1-fill-dev"
:
"black-forest-labs/FLUX.1-Fill-dev"
,
}
NUNCHAKU_REPO_PATTERN_MAP = {
    "flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
    "flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
    "shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
    "flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
    "flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
    "flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
}

LORA_PATH_MAP = {
    "hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
    "turbo8": "alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors",
    "realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
    "ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
    "anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
    "sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
    "yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
    "haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
    "canny": "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors",
    "depth": "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors",
}
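Each LORA_PATH_MAP value packs a Hugging Face repo id and a weight filename into one string; run_test below recovers them with os.path.dirname/os.path.basename before calling load_lora_weights. A quick illustration (the explicit hf_hub_download call is an assumption, mirroring what diffusers does internally):

import os

from huggingface_hub import hf_hub_download

lora_path = "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors"
repo_id = os.path.dirname(lora_path)  # "ByteDance/Hyper-SD"
weight_name = os.path.basename(lora_path)  # "Hyper-FLUX.1-dev-8steps-lora.safetensors"
local_file = hf_hub_download(repo_id=repo_id, filename=weight_name)  # cached local copy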
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
    os.makedirs(save_dir, exist_ok=True)
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)

    if task == "canny":
        processor = CannyDetector()
    elif task == "depth":
        processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
    elif task == "redux":
        processor = FluxPriorReduxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
        ).to("cuda")
    else:
        assert task in ["t2i", "fill"]
        processor = None

    for row in tqdm(dataset):
        filename = row["filename"]
        prompt = row["prompt"]
        _forward_kwargs = {k: v for k, v in forward_kwargs.items()}
        if task == "canny":
            assert forward_kwargs.get("height", 1024) == 1024
            assert forward_kwargs.get("width", 1024) == 1024
            control_image = load_image(row["canny_image_path"])
            control_image = processor(
                control_image,
                low_threshold=50,
                high_threshold=200,
                detect_resolution=1024,
                image_resolution=1024,
            )
            _forward_kwargs["control_image"] = control_image
        elif task == "depth":
            control_image = load_image(row["depth_image_path"])
            control_image = processor(control_image)[0].convert("RGB")
            _forward_kwargs["control_image"] = control_image
        elif task == "fill":
            image = load_image(row["image_path"])
            mask_image = load_image(row["mask_image_path"])
            _forward_kwargs["image"] = image
            _forward_kwargs["mask_image"] = mask_image
        elif task == "redux":
            image = load_image(row["image_path"])
            _forward_kwargs.update(processor(image))
        seed = hash_str_to_int(filename)
        if task == "redux":
            image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
        else:
            image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
        image.save(os.path.join(save_dir, f"{filename}.png"))
    torch.cuda.empty_cache()
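run_pipeline seeds every generation with hash_str_to_int(filename), another helper from tests/utils.py outside this hunk, so the 16-bit reference and the 4-bit run sample identical noise per image. A plausible sketch — a digest-based hash, since Python's built-in hash() is salted per process and would break reproducibility:

import hashlib


def hash_str_to_int_sketch(s: str) -> int:
    # stable across processes and machines, unlike hash(s)
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")  # 32-bit seed for torch.Generator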
def run_test(
    precision: str = "int4",
    model_name: str = "flux.1-schnell",
    dataset_name: str = "MJHQ",
    task: str = "t2i",
    dtype: str | torch.dtype = torch.bfloat16,  # the full-precision dtype
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 3.5,
    use_qencoder: bool = False,
    attention_impl: str = "flashattn2",  # "flashattn2" or "nunchaku-fp16"
    cpu_offload: bool = False,
    cache_threshold: float = 0,
    lora_names: str | list[str] | None = None,
    lora_strengths: float | list[float] = 1.0,
    max_dataset_size: int = 20,
    i2f_mode: str | None = None,
    expected_lpips: float = 0.5,
):
    if isinstance(dtype, str):
        dtype_str = dtype
        if dtype == "bf16":
            dtype = torch.bfloat16
        else:
            assert dtype == "fp16"
            dtype = torch.float16
    else:
        if dtype == torch.bfloat16:
            dtype_str = "bf16"
        else:
            assert dtype == torch.float16
            dtype_str = "fp16"
    dataset = get_dataset(name=dataset_name, max_dataset_size=max_dataset_size)
    model_id_16bit = ORIGINAL_REPO_MAP[model_name]

    folder_name = f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"

    if lora_names is None:
        lora_names = []
    elif isinstance(lora_names, str):
        lora_names = [lora_names]
    if len(lora_names) > 0:
        if isinstance(lora_strengths, (int, float)):
            lora_strengths = [lora_strengths]
        assert len(lora_names) == len(lora_strengths)
        for lora_name, lora_strength in zip(lora_names, lora_strengths):
            folder_name += f"-{lora_name}_{lora_strength}"

    if not os.path.exists(os.path.join("test_results", "ref")):
        ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
    else:
        ref_root = os.path.join("test_results", "ref")
    save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)
    if task in ["t2i", "redux"]:
        pipeline_cls = FluxPipeline
    elif task in ["canny", "depth"]:
        pipeline_cls = FluxControlPipeline
    elif task == "fill":
        pipeline_cls = FluxFillPipeline
    else:
        raise NotImplementedError(f"Unknown task {task}!")
    if not already_generate(save_dir_16bit, max_dataset_size):
        pipeline_init_kwargs = {"text_encoder": None, "text_encoder_2": None} if task == "redux" else {}
        pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
        pipeline = pipeline.to("cuda")

        if len(lora_names) > 0:
            for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
                lora_path = LORA_PATH_MAP[lora_name]
                pipeline.load_lora_weights(
                    os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name=f"lora_{i}"
                )
            pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)

        run_pipeline(
            dataset=dataset,
            task=task,
            pipeline=pipeline,
            save_dir=save_dir_16bit,
            forward_kwargs={
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            },
        )
        del pipeline
        # release the GPU memory
        torch.cuda.empty_cache()
    precision_str = precision
    if use_qencoder:
        precision_str += "-qe"
    if attention_impl == "flashattn2":
        precision_str += "-fa2"
    else:
        assert attention_impl == "nunchaku-fp16"
        precision_str += "-nfp16"
    if cpu_offload:
        precision_str += "-co"
    if cache_threshold > 0:
        precision_str += f"-cache{cache_threshold}"
    if i2f_mode is not None:
        precision_str += f"-i2f{i2f_mode}"
    save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
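As a worked example of this naming scheme: precision="int4", attention_impl="nunchaku-fp16", and cpu_offload=True with the default model and sampling settings place the 4-bit outputs in test_results/bf16/int4-nfp16-co/flux.1-schnell/w1024h1024t4g3.5. Every configuration therefore caches to a distinct directory, which lets already_generate skip runs that already finished.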
    if not already_generate(save_dir_4bit, max_dataset_size):
        pipeline_init_kwargs = {}
        model_id_4bit = NUNCHAKU_REPO_PATTERN_MAP[model_name].format(precision=precision)
        if i2f_mode is not None:
            nunchaku._C.utils.set_faster_i2f_mode(i2f_mode)
        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
            model_id_4bit, offload=cpu_offload, torch_dtype=dtype
        )
        transformer.set_attention_impl(attention_impl)
        if len(lora_names) > 0:
            if len(lora_names) == 1:
                # directly load the lora
                lora_path = LORA_PATH_MAP[lora_names[0]]
                lora_strength = lora_strengths[0]
                transformer.update_lora_params(lora_path)
                transformer.set_lora_strength(lora_strength)
            else:
                composed_lora = compose_lora(
                    [
                        (LORA_PATH_MAP[lora_name], lora_strength)
                        for lora_name, lora_strength in zip(lora_names, lora_strengths)
                    ]
                )
                transformer.update_lora_params(composed_lora)
        pipeline_init_kwargs["transformer"] = transformer
        if task == "redux":
            pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
        elif use_qencoder:
            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
            pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
        if cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to("cuda")
        run_pipeline(
            dataset=dataset,
            task=task,
            pipeline=pipeline,
            save_dir=save_dir_4bit,
            forward_kwargs={
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            },
        )
        del transformer
        del pipeline
        # release the GPU memory
        torch.cuda.empty_cache()

    lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.05
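run_test is the single entry point the refactored tests build on, as tests/flux/test_turing.py above shows. A hypothetical additional test stays equally thin (the cache threshold and LPIPS bound here are illustrative values, not measured ones):

import pytest

from nunchaku.utils import get_precision

from .utils import run_test


@pytest.mark.parametrize("cache_threshold,expected_lpips", [(0.12, 0.3)])
def test_flux_dev_cached(cache_threshold: float, expected_lpips: float):
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        num_inference_steps=50,
        cache_threshold=cache_threshold,
        expected_lpips=expected_lpips,
    )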