fengzch-das / nunchaku · Commits

Commit 5de6d7cf (Unverified)
Authored Aug 14, 2025 by SMG · Committed by GitHub on Aug 13, 2025
Parent: 3bcc2d43

fix: enable correct batch processing in teacache (#601)

* fix teacache_batch
* lint
Showing 3 changed files with 46 additions and 118 deletions (+46 −118):

- examples/flux.1-dev-teacache-batch.py (+39 −0)
- nunchaku/caching/teacache.py (+5 −114)
- nunchaku/caching/utils.py (+2 −4)
examples/flux.1-dev-teacache-batch.py (new file, mode 0 → 100644)
```python
import time

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

start_time = time.time()
prompts = [
    "A cheerful woman in a pastel dress, holding a basket of colorful Easter eggs with a sign that says 'Happy Easter'",
    "A young peace activist with a gentle smile, holding a handmade sign that says 'Peace'",
    "A friendly chef wearing a tall white hat, holding a wooden spoon with a sign that says 'Let's Cook!'",
]
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True):
    image = pipeline(
        prompts,
        num_inference_steps=50,
        guidance_scale=3.5,
        height=1024,
        width=1024,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")

image[0].save(f"flux.1-dev-{precision}1-tc.png")
image[1].save(f"flux.1-dev-{precision}2-tc.png")
image[2].save(f"flux.1-dev-{precision}3-tc.png")
```
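For orientation, the `rel_l1_thresh=0.3` argument in the example above controls how aggressively TeaCache skips transformer computation across the 50 denoising steps. The sketch below is a minimal, illustrative version of the relative-L1 skip criterion that such a threshold gates; the function name and signature are hypothetical, and it omits details of the real implementation (for example, how the distance is rescaled before being accumulated).

```python
import torch


def should_reuse_cache(
    modulated_input: torch.Tensor,
    previous_modulated_input: torch.Tensor,
    accumulated_distance: float,
    rel_l1_thresh: float = 0.3,
) -> tuple[bool, float]:
    """Illustrative TeaCache-style skip decision (not nunchaku's actual code)."""
    # Relative L1 change between the current and previous modulated inputs.
    rel_l1 = (
        (modulated_input - previous_modulated_input).abs().mean()
        / previous_modulated_input.abs().mean().clamp_min(1e-8)
    ).item()
    accumulated_distance += rel_l1
    if accumulated_distance < rel_l1_thresh:
        # Small accumulated change: reuse the cached residual and keep accumulating.
        return True, accumulated_distance
    # Change grew too large: run the transformer blocks this step and reset the counter.
    return False, 0.0
```

Raising the threshold skips more steps (faster, at some quality cost); lowering it recomputes more often.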
nunchaku/caching/teacache.py
@@ -161,16 +161,8 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
```python
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
```
@@ -250,60 +242,10 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
```python
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                )

                # controlnet residual
                if controlnet_block_samples is not None:
                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    # For Xlabs ControlNet.
                    if controlnet_blocks_repeat:
                        hidden_states = (
                            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                        )
                    else:
                        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):  # type: ignore
                        def custom_forward(*inputs):  # type: ignore
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_single_block_samples is not None:
                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                        + controlnet_single_block_samples[index_block // interval_control]
                    )

            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
            self.previous_residual = hidden_states - ori_hidden_states
        else:
            for index_block, block in enumerate(self.transformer_blocks):
```
@@ -335,61 +277,10 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
```python
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                )

                # controlnet residual
                if controlnet_block_samples is not None:
                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    # For Xlabs ControlNet.
                    if controlnet_blocks_repeat:
                        hidden_states = (
                            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                        )
                    else:
                        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):  # type: ignore
                        def custom_forward(*inputs):  # type: ignore
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_single_block_samples is not None:
                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                        + controlnet_single_block_samples[index_block // interval_control]
                    )

            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output: torch.FloatTensor = self.proj_out(hidden_states)
```
@@ -398,7 +289,7 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
```python
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return output
            return (output,)

        return Transformer2DModelOutput(sample=output)
```
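The hunk headers above all come from `make_teacache_forward`, whose returned function serves as the transformer's forward pass while the `TeaCache` context manager from the example is active. As a rough mental model only (this is not nunchaku's actual implementation), such a context manager can be sketched as a temporary forward swap:

```python
import types
from contextlib import contextmanager


@contextmanager
def swap_forward(model, make_forward):
    """Illustrative sketch with hypothetical names: temporarily replace
    ``model.forward`` with a caching variant and restore it on exit."""
    original_forward = model.forward
    # Bind the generated forward (e.g. one built by a make_*_forward factory) to the model.
    model.forward = types.MethodType(make_forward(), model)
    try:
        yield model
    finally:
        model.forward = original_forward
```

The point of the pattern is that caching stays scoped to the `with` block, so the pipeline behaves normally outside it.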
nunchaku/caching/utils.py
@@ -821,7 +821,7 @@ class FluxCachedTransformerBlocks(nn.Module):
```python
        -----
        If batch size > 2 or residual_diff_threshold <= 0, caching is disabled for now.
        """
        batch_size = hidden_states.shape[0]
        # batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]
```
@@ -860,9 +860,7 @@ class FluxCachedTransformerBlocks(nn.Module):
```python
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))

        if (self.residual_diff_threshold_multi < 0.0) or (batch_size > 1):
            if batch_size > 1 and self.verbose:
                print("Batch size > 1 currently not supported")
        if self.residual_diff_threshold_multi < 0.0:
            hidden_states = self.m.forward(
                hidden_states,
```
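The guard shown above is the code path that, per its printed message, treated `batch_size > 1` as unsupported, and it is what the commit title refers to. As a generic illustration only (not nunchaku's implementation), one way a residual-diff cache can stay valid for batched inputs is to reuse the cache only when every sample in the batch is below the threshold:

```python
import torch


def residuals_similar_per_sample(prev: torch.Tensor, curr: torch.Tensor, threshold: float) -> bool:
    """Illustrative only: permit cache reuse when *every* sample in the batch
    has a relative L1 change below ``threshold``."""
    # Flatten all non-batch dimensions -> shape (B, N), then average per sample.
    diff = (curr - prev).flatten(1).abs().mean(dim=1)
    ref = prev.flatten(1).abs().mean(dim=1).clamp_min(1e-8)
    return bool(((diff / ref) < threshold).all())
```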