fengzch-das / nunchaku · Commits

Commit 0b1891cd, authored Mar 10, 2025 by muyangli, committed by Zhekai Zhang on Apr 01, 2025

[feat] add first block cache

Parent: 39f90121
Showing 15 changed files with 540 additions and 77 deletions (+540 −77).
examples/int4-flux.1-dev-cache.py                  +14  −0
nunchaku/__version__.py                            +1   −1
nunchaku/caching/diffusers_adapters/__init__.py    +2   −23
nunchaku/caching/diffusers_adapters/flux.py        +8   −30
nunchaku/caching/utils.py                          +25  −12
nunchaku/models/transformers/transformer_flux.py   +9   −2
nunchaku/test.py                                   +4   −2
scripts/build_windows_wheels.ps1                   +32  −0
src/FluxModel.cpp                                  +7   −7
tests/flux/test_flux_cache.py                      +49  −0
tests/flux/test_flux_dev.py                        +185 −0
tests/flux/test_flux_dev_loras.py                  +48  −0
tests/flux/test_flux_memory.py                     +41  −0
tests/flux/test_flux_schnell.py                    +97  −0
tests/flux/utils.py                                +18  −0
examples/flux-dyn-caching.py → examples/int4-flux.1-dev-cache.py
 import torch
 from diffusers import FluxPipeline

-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
-import time

-transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev", offload=True)
 pipeline = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
-).to("cuda")
+    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+)
+pipeline.enable_sequential_cpu_offload()
 apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
-image = pipeline(
-    ["A cat holding a sign that says hello world"], width=1024, height=1024, num_inference_steps=32, guidance_scale=0
-).images[0]
-image.save("flux.1-schnell-int4-0.12.png")
+image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
+image.save("flux.1-dev-int4.png")
nunchaku/__version__.py

-__version__ = "0.1.4"
+__version__ = "0.1.5"
nunchaku/caching/diffusers_adapters/__init__.py

-import importlib
-
 from diffusers import DiffusionPipeline


-def apply_cache_on_transformer(transformer, *args, **kwargs):
-    transformer_cls_name = transformer.__class__.__name__
-    if False:
-        pass
-    elif transformer_cls_name.startswith("Flux"):
-        adapter_name = "flux"
-    else:
-        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
-
-    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
-    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
-
-
 def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
     assert isinstance(pipe, DiffusionPipeline)
     pipe_cls_name = pipe.__class__.__name__
-    if False:
-        pass
-    elif pipe_cls_name.startswith("Flux"):
-        adapter_name = "flux"
+    if pipe_cls_name.startswith("Flux"):
+        from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-    print("Registering Flux")
-    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
     return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
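After this change the package __init__ exposes only the pipeline-level entry point; the per-transformer helper remains in the flux adapter module. If one wanted to patch only the transformer rather than a whole pipeline, importing that module directly should still work — a minimal sketch, assuming the module layout shown in this commit:

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_transformer

# Patch the transformer in place; 0.12 matches the threshold used by the example above.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
transformer = apply_cache_on_transformer(transformer, residual_diff_threshold=0.12)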
nunchaku/caching/diffusers_adapters/flux.py

@@ -4,14 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel

-from nunchaku.caching import utils
+from ...caching import utils


-def apply_cache_on_transformer(
-    transformer: FluxTransformer2DModel,
-    *,
-    residual_diff_threshold=0.05,
-):
+def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
     if getattr(transformer, "_is_cached", False):
         return transformer
...
@@ -29,38 +25,20 @@ def apply_cache_on_transformer(
     original_forward = transformer.forward

     @functools.wraps(original_forward)
-    def new_forward(
-        self,
-        *args,
-        **kwargs,
-    ):
-        with unittest.mock.patch.object(
-            self,
-            "transformer_blocks",
-            cached_transformer_blocks,
-        ), unittest.mock.patch.object(
-            self,
-            "single_transformer_blocks",
-            dummy_single_transformer_blocks,
-        ):
-            return original_forward(
-                *args,
-                **kwargs,
-            )
+    def new_forward(self, *args, **kwargs):
+        with (
+            unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
+            unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
+        ):
+            return original_forward(*args, **kwargs)

     transformer.forward = new_forward.__get__(transformer)
     transformer._is_cached = True
     return transformer


-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
-    *,
-    shallow_patch: bool = False,
-    **kwargs,
-):
+def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
     if not getattr(pipe, "_is_cached", False):
         original_call = pipe.__class__.__call__
...
nunchaku/caching/utils.py

-# This cachaing functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
+# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
 import contextlib
 import dataclasses
...
@@ -6,6 +6,7 @@ from collections import defaultdict
 from typing import DefaultDict, Dict

 import torch
+from torch import nn


 @dataclasses.dataclass
...
@@ -34,7 +35,6 @@ class CacheContext:
         self.buffers.clear()


 @torch.compiler.disable
 def get_buffer(name):
     cache_context = get_current_cache_context()
...
@@ -49,7 +49,6 @@ def set_buffer(name, buffer):
     cache_context.set_buffer(name, buffer)


 _current_cache_context = None
...
@@ -79,8 +78,11 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
     diff = mean_diff / mean_t1
     return diff.item() < threshold


 @torch.compiler.disable
-def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
+def apply_prev_hidden_states_residual(
+    hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
     hidden_states_residual = get_buffer("hidden_states_residual")
     assert hidden_states_residual is not None, "hidden_states_residual must be set before"
     hidden_states = hidden_states_residual + hidden_states
...
@@ -94,6 +96,7 @@ def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
     return hidden_states, encoder_hidden_states


 @torch.compiler.disable
 def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
     prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
...
@@ -105,7 +108,8 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
         )
     )
     return can_use_cache


-class CachedTransformerBlocks(torch.nn.Module):
+class CachedTransformerBlocks(nn.Module):
     def __init__(
         self,
         *,
...
@@ -113,6 +117,7 @@ class CachedTransformerBlocks(torch.nn.Module):
         residual_diff_threshold,
         return_hidden_states_first=True,
         return_hidden_states_only=False,
+        verbose: bool = False,
     ):
         super().__init__()
         self.transformer = transformer
...
@@ -121,6 +126,7 @@ class CachedTransformerBlocks(torch.nn.Module):
         self.residual_diff_threshold = residual_diff_threshold
         self.return_hidden_states_first = return_hidden_states_first
         self.return_hidden_states_only = return_hidden_states_only
+        self.verbose = verbose

     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
         batch_size = hidden_states.shape[0]
...
@@ -130,7 +136,8 @@ class CachedTransformerBlocks(torch.nn.Module):
             first_transformer_block = self.transformer_blocks[0]
             encoder_hidden_states, hidden_states = first_transformer_block(
                 hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
             )
             return (
                 hidden_states
...
@@ -145,7 +152,8 @@ class CachedTransformerBlocks(torch.nn.Module):
         original_hidden_states = hidden_states
         first_transformer_block = self.transformer_blocks[0]
         encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
             0, hidden_states, encoder_hidden_states, *args, **kwargs
         )
         first_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
...
@@ -159,12 +167,14 @@ class CachedTransformerBlocks(torch.nn.Module):
         torch._dynamo.graph_break()
         if can_use_cache:
             del first_hidden_states_residual
-            print("Cache hit!!!")
+            if self.verbose:
+                print("Cache hit!!!")
             hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                 hidden_states, encoder_hidden_states
             )
         else:
-            print("Cache miss!!!")
+            if self.verbose:
+                print("Cache miss!!!")
             set_buffer("first_hidden_states_residual", first_hidden_states_residual)
             del first_hidden_states_residual
             (
...
@@ -192,9 +202,12 @@ class CachedTransformerBlocks(torch.nn.Module):
             original_hidden_states = hidden_states
             original_encoder_hidden_states = encoder_hidden_states
             encoder_hidden_states, hidden_states = first_transformer_block.forward(
                 hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 skip_first_layer=True,
                 *args,
                 **kwargs,
             )
             hidden_states = hidden_states.contiguous()
             encoder_hidden_states = encoder_hidden_states.contiguous()
...
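Taken together, these changes make the cache hit/miss decision quiet by default and let callers opt back into logging with verbose=True. For orientation, the decision itself is a relative-difference test between the current step's first-block residual and the residual stored at the last cache miss; a minimal standalone sketch of that test (simplified from are_two_tensors_similar and get_can_use_cache above, not a drop-in copy):

import torch

def residuals_are_similar(prev: torch.Tensor, curr: torch.Tensor, threshold: float) -> bool:
    # Relative mean absolute difference between the stored and current first-block residuals.
    mean_diff = (prev - curr).abs().mean()
    mean_prev = prev.abs().mean()
    return (mean_diff / mean_prev).item() < threshold

# Conceptually, per denoising step:
#   residual = first_block(hidden_states) - hidden_states
#   if a stored residual exists and residuals_are_similar(stored, residual, threshold):
#       reuse the cached residuals for the remaining blocks          (cache hit)
#   else:
#       store the residual and run the remaining transformer blocks  (cache miss)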
nunchaku/models/transformers/transformer_flux.py

@@ -59,7 +59,13 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)

         hidden_states = self.m.forward(
-            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single, skip_first_layer
+            hidden_states,
+            encoder_hidden_states,
+            temb,
+            rotary_emb_img,
+            rotary_emb_txt,
+            rotary_emb_single,
+            skip_first_layer,
         )

         hidden_states = hidden_states.to(original_dtype).to(original_device)
...
@@ -103,7 +109,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)

         hidden_states, encoder_hidden_states = self.m.forward_layer(
-            0, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
+            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
         )

         hidden_states = hidden_states.to(original_dtype).to(original_device)
         encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
...
nunchaku/test.py

@@ -9,11 +9,13 @@ if __name__ == "__main__":
     precision = "fp4" if sm == "120" else "int4"

     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=True, precision=precision
     )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
-    ).to("cuda")
+    )
     image = pipeline(
         "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
     ).images[0]
     image.save("flux.1-schnell.png")
scripts/build_windows_wheels.ps1 (new file)

param (
    [string]$PYTHON_VERSION,
    [string]$TORCH_VERSION,
    [string]$CUDA_VERSION,
    [string]$MAX_JOBS = ""
)

# Conda environment name
$ENV_NAME = "build_env_$PYTHON_VERSION"

# Create the Conda environment
conda create -y -n $ENV_NAME python=$PYTHON_VERSION
conda activate $ENV_NAME

# Install dependencies
conda install -y ninja setuptools wheel pip
pip install --no-cache-dir torch==$TORCH_VERSION numpy --index-url "https://download.pytorch.org/whl/cu$($CUDA_VERSION.Substring(0, 2))/"

# Set environment variables
$env:NUNCHAKU_INSTALL_MODE = "ALL"
$env:NUNCHAKU_BUILD_WHEELS = "1"
$env:MAX_JOBS = $MAX_JOBS

# Change to the directory containing this script and build the wheels
Set-Location -Path $PSScriptRoot
if (Test-Path "build") {
    Remove-Item -Recurse -Force "build"
}
python -m build --wheel --no-isolation

# Deactivate the Conda environment
conda deactivate

Write-Output "Build complete!"
src/FluxModel.cpp

@@ -485,16 +485,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     } else {
         raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
         checkCUDA(cudaMemcpy2DAsync(
             raw_attn_output_split.data_ptr(),
             num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
             raw_attn_output.data_ptr(),
             (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
             num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
             batch_size,
             cudaMemcpyDeviceToDevice,
             stream));
     }
     spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
     debug("img.raw_attn_output_split", raw_attn_output_split);
...
@@ -550,16 +550,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     } else {
         raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
         checkCUDA(cudaMemcpy2DAsync(
             raw_attn_output_split.data_ptr(),
             num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
             raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
             (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
             num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
             batch_size,
             cudaMemcpyDeviceToDevice,
             stream));
     }
     spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
     debug("context.raw_attn_output_split", raw_attn_output_split);
...
@@ -585,7 +585,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
 #else
     auto norm_hidden_states = encoder_hidden_states;
 #endif
     // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
     // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
...
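In both hunks, the cudaMemcpy2DAsync call copies one token slice per batch element out of the packed attention output: the destination pitch and copy width equal the slice's row size, the source pitch is the full (image + context) row size, and the copy height is batch_size, so the first hunk extracts the leading image tokens and the second the trailing context tokens. A rough NumPy analogue of the same strided copies, with illustrative shapes (width standing in for num_heads * dim_head):

import numpy as np

batch_size, num_tokens_img, num_tokens_context, width = 2, 6, 3, 4  # illustrative sizes
raw_attn_output = np.random.randn(batch_size, num_tokens_img + num_tokens_context, width)

# First hunk: leading num_tokens_img rows of every batch element.
img_split = raw_attn_output[:, :num_tokens_img, :].copy()
# Second hunk: trailing num_tokens_context rows (source offset by num_tokens_img rows).
context_split = raw_attn_output[:, num_tokens_img:, :].copy()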
tests/flux/test_flux_cache.py (new file)

import pytest

from .test_flux_dev import run_test_flux_dev


@pytest.mark.parametrize(
    "height,width,num_inference_steps,cache_threshold,lora_name,use_qencoder,cpu_offload,expected_lpips",
    [
        # (1024, 1024, 50, 0, None, False, False, 0.5),  # 13min20s 5min55s 0.19539418816566467
        # (1024, 1024, 50, 0.05, None, False, True, 0.5),  # 7min11s 0.21917256712913513
        # (1024, 1024, 50, 0.12, None, False, True, 0.5),  # 2min58s, 0.24101486802101135
        # (1024, 1024, 50, 0.2, None, False, True, 0.5),  # 2min23s, 0.3101634383201599
        # (1024, 1024, 50, 0.5, None, False, True, 0.5),  # 1min44s 0.6543852090835571
        # (1024, 1024, 30, 0, None, False, False, 0.5),  # 8min2s 3min40s 0.2141970843076706
        # (1024, 1024, 30, 0.05, None, False, True, 0.5),  # 4min57 0.21297718584537506
        # (1024, 1024, 30, 0.12, None, False, True, 0.5),  # 2min34 0.25963714718818665
        # (1024, 1024, 30, 0.2, None, False, True, 0.5),  # 1min51 0.31409069895744324
        # (1024, 1024, 20, 0, None, False, False, 0.5),  # 5min25 2min29 0.18987375497817993
        # (1024, 1024, 20, 0.05, None, False, True, 0.5),  # 3min3 0.17194810509681702
        # (1024, 1024, 20, 0.12, None, False, True, 0.5),  # 2min15 0.19407868385314941
        # (1024, 1024, 20, 0.2, None, False, True, 0.5),  # 1min48 0.2832985818386078
        (1024, 1024, 30, 0.12, None, False, False, 0.26),
        (512, 2048, 30, 0.12, "anime", True, False, 0.4),
    ],
)
def test_flux_dev_base(
    height: int,
    width: int,
    num_inference_steps: int,
    cache_threshold: float,
    lora_name: str | None,
    use_qencoder: bool,
    cpu_offload: bool,
    expected_lpips: float,
):
    run_test_flux_dev(
        precision="int4",
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=use_qencoder,
        cpu_offload=cpu_offload,
        lora_name=lora_name,
        lora_scale=1,
        cache_threshold=cache_threshold,
        max_dataset_size=16,
        expected_lpips=expected_lpips,
    )
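The test simply forwards one parameter row to run_test_flux_dev from tests/flux/test_flux_dev.py; invoking that helper directly for a single cached configuration would look roughly like this (values taken from the first enabled parametrize row above, shown for illustration):

from tests.flux.test_flux_dev import run_test_flux_dev

# One cached FLUX.1-dev configuration, mirroring the first enabled row of the parametrization.
run_test_flux_dev(
    precision="int4",
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=3.5,
    use_qencoder=False,
    cpu_offload=False,
    lora_name=None,
    lora_scale=1,
    cache_threshold=0.12,  # forwarded to apply_cache_on_pipe as residual_diff_threshold
    max_dataset_size=16,
    expected_lpips=0.26,
)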
tests/flux/test_t2i.py → tests/flux/test_flux_dev.py

@@ -6,111 +6,13 @@ import torch
 from diffusers import FluxPipeline
 from peft.tuners import lora
 from safetensors.torch import save_file
-from tqdm import tqdm

 from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
 from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers

+from .utils import run_pipeline
 from ..data import get_dataset
-from ..utils import already_generate, compute_lpips, hash_str_to_int
-
-
-def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
-    ...  # moved to tests/flux/utils.py (shown below)
-
-
-@pytest.mark.parametrize(...)
-def test_flux_schnell(precision, height, width, num_inference_steps, guidance_scale, use_qencoder, cpu_offload, max_dataset_size, expected_lpips):
-    ...  # moved to tests/flux/test_flux_schnell.py (shown below)
+from ..utils import already_generate, compute_lpips

 LORA_PATH_MAP = {
     "hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
...
@@ -133,6 +35,7 @@ def run_test_flux_dev(
     cpu_offload: bool,
     lora_name: str | None,
     lora_scale: float,
+    cache_threshold: float,
     max_dataset_size: int,
     expected_lpips: float,
 ):
@@ -140,7 +43,6 @@ def run_test_flux_dev(
         "results",
         "dev",
         f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
-        + ("-qencoder" if use_qencoder else "")
         + (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
     )
     dataset = get_dataset(
...
@@ -177,7 +79,12 @@ def run_test_flux_dev(
     # release the gpu memory
     torch.cuda.empty_cache()

-    save_dir_4bit = os.path.join(save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}")
+    name = precision
+    name += "-qencoder" if use_qencoder else ""
+    name += "-offload" if cpu_offload else ""
+    name += f"-cache{cache_threshold:.2f}" if cache_threshold > 0 else ""
+    save_dir_4bit = os.path.join(save_root, name)
     if not already_generate(save_dir_4bit, max_dataset_size):
         pipeline_init_kwargs = {}
         if precision == "int4":
...
@@ -221,6 +128,9 @@ def run_test_flux_dev(
         pipeline = pipeline.to("cuda")
         if cpu_offload:
             pipeline.enable_sequential_cpu_offload()
+        if cache_threshold > 0:
+            apply_cache_on_pipe(pipeline, residual_diff_threshold=cache_threshold)
         run_pipeline(
             dataset,
             pipeline,
...
@@ -252,6 +162,7 @@ def test_flux_dev_base(cpu_offload: bool):
         cpu_offload=cpu_offload,
         lora_name=None,
         lora_scale=0,
+        cache_threshold=0,
         max_dataset_size=8,
         expected_lpips=0.16,
     )
...
@@ -268,85 +179,7 @@ def test_flux_dev_qencoder_800x600():
         cpu_offload=False,
         lora_name=None,
         lora_scale=0,
+        cache_threshold=0,
         max_dataset_size=8,
         expected_lpips=0.36,
     )
-
-
-def test_flux_dev_hypersd8_1080x1920():
-    ...  # moved to tests/flux/test_flux_dev_loras.py (shown below)
-
-
-@pytest.mark.parametrize(...)
-def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
-    ...  # moved to tests/flux/test_flux_dev_loras.py (shown below)
-
-
-@pytest.mark.parametrize(...)
-def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
-    ...  # moved to tests/flux/test_flux_memory.py (shown below)
tests/flux/test_flux_dev_loras.py (new file)

import pytest

from tests.flux.test_flux_dev import run_test_flux_dev


@pytest.mark.parametrize(
    "num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
    [
        (25, "realism", 0.9, False, 0.16),
        (25, "ghibsky", 1, False, 0.16),
        (28, "anime", 1, False, 0.27),
        (24, "sketch", 1, False, 0.35),
        (28, "yarn", 1, False, 0.22),
        (25, "haunted_linework", 1, False, 0.34),
    ],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
    run_test_flux_dev(
        precision="int4",
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=cpu_offload,
        lora_name=lora_name,
        lora_scale=lora_scale,
        cache_threshold=0,
        max_dataset_size=8,
        expected_lpips=expected_lpips,
    )


def test_flux_dev_hypersd8_1080x1920():
    run_test_flux_dev(
        precision="int4",
        height=1080,
        width=1920,
        num_inference_steps=8,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_name="hypersd8",
        lora_scale=0.125,
        cache_threshold=0,
        max_dataset_size=8,
        expected_lpips=0.44,
    )
tests/flux/test_flux_memory.py (new file)

import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel


@pytest.mark.parametrize(
    "use_qencoder,cpu_offload,memory_limit",
    [
        (False, False, 17),
        (False, True, 13),
        (True, False, 12),
        (True, True, 6),
    ],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
    torch.cuda.reset_peak_memory_stats()
    pipeline_init_kwargs = {
        "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
            "mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
        )
    }
    if use_qencoder:
        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    ).to("cuda")
    if cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    pipeline("A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0)
    memory = torch.cuda.max_memory_reserved(0) / 1024**3
    assert memory < memory_limit
    del pipeline
    # release the gpu memory
    torch.cuda.empty_cache()
tests/flux/test_flux_schnell.py (new file)

import os

import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from tests.data import get_dataset
from tests.flux.utils import run_pipeline
from tests.utils import already_generate, compute_lpips


@pytest.mark.parametrize(
    "precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
    [
        ("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
        ("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
        ("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
        ("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
        ("int4", 600, 800, 4, 0, False, False, 16, 0.29),
    ],
)
def test_flux_schnell(
    precision: str,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    use_qencoder: bool,
    cpu_offload: bool,
    max_dataset_size: int,
    expected_lpips: float,
):
    dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
    save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
    save_dir_16bit = os.path.join(save_root, "bf16")
    if not already_generate(save_dir_16bit, max_dataset_size):
        pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        pipeline = pipeline.to("cuda")
        run_pipeline(
            dataset,
            pipeline,
            save_dir=save_dir_16bit,
            forward_kwargs={
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            },
        )
        del pipeline
        # release the gpu memory
        torch.cuda.empty_cache()

    save_dir_4bit = os.path.join(
        save_root,
        f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else ""),
    )
    if not already_generate(save_dir_4bit, max_dataset_size):
        pipeline_init_kwargs = {}
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
            )
        else:
            assert precision == "fp4"
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
            )
        pipeline_init_kwargs["transformer"] = transformer
        if use_qencoder:
            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
            pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
        pipeline = pipeline.to("cuda")
        if cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        run_pipeline(
            dataset,
            pipeline,
            save_dir=save_dir_4bit,
            forward_kwargs={
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            },
        )
        del pipeline
        # release the gpu memory
        torch.cuda.empty_cache()

    lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.05
tests/flux/utils.py (new file)

import os

import torch
from diffusers import FluxPipeline
from tqdm import tqdm

from ..utils import hash_str_to_int


def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
    os.makedirs(save_dir, exist_ok=True)
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for row in tqdm(dataset):
        filename = row["filename"]
        prompt = row["prompt"]
        seed = hash_str_to_int(filename)
        image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **forward_kwargs).images[0]
        image.save(os.path.join(save_dir, f"{filename}.png"))
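run_pipeline drives any FluxPipeline over a dataset of rows with "filename" and "prompt" keys, seeding each image from the filename so reruns are reproducible. A hedged usage sketch (the two-row dataset literal and the save_dir are illustrative; the tests themselves use get_dataset(name="MJHQ", ...)):

import torch
from diffusers import FluxPipeline

from tests.flux.utils import run_pipeline

# Illustrative dataset; each row needs a "filename" and a "prompt".
dataset = [
    {"filename": "cat_sign", "prompt": "A cat holding a sign that says hello world"},
    {"filename": "city_rain", "prompt": "A rainy city street at night"},
]
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
run_pipeline(dataset, pipeline, save_dir="results/demo", forward_kwargs={"num_inference_steps": 4, "guidance_scale": 0})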