Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
ca4d4b46
Unverified
Commit
ca4d4b46
authored
Jun 06, 2025
by
SMG
Committed by
GitHub
Jun 05, 2025
Browse files
fix: segmentation fault when using FBcache with offload=True (#440)
* fix: cache issue if offload is set to True * fix: lint
parent
7f71d3ac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
0 deletions
+55
-0
examples/flux.1-dev-double_cache_offloading.py
examples/flux.1-dev-double_cache_offloading.py
+28
-0
nunchaku/csrc/flux.h
nunchaku/csrc/flux.h
+8
-0
src/FluxModel.cpp
src/FluxModel.cpp
+16
-0
src/FluxModel.h
src/FluxModel.h
+3
-0
No files found.
examples/flux.1-dev-double_cache_offloading.py
0 → 100644
View file @
ca4d4b46
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

# Quantization precision tag — interpolated into the checkpoint filename
# below (presumably something like "int4"/"fp4"; determined by get_precision).
precision = get_precision()

# Load the quantized FLUX transformer with CPU offloading enabled so the
# block parameters are streamed to the GPU lazily instead of resident.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
    offload=True,
)

# Assemble the standard diffusers pipeline around the quantized transformer.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable double first-block caching, with separate residual-difference
# thresholds for the multi-modal and the single transformer block stages.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

# Generate one image and save it to disk.
image = pipeline(
    ["A cat holding a sign that says hello world"],
    num_inference_steps=50,
).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
nunchaku/csrc/flux.h
View file @
ca4d4b46
...
@@ -143,9 +143,17 @@ public:
...
@@ -143,9 +143,17 @@ public:
temb
=
temb
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
if
(
net
->
isOffloadEnabled
())
{
net
->
single_transformer_blocks
.
at
(
idx
)
->
loadLazyParams
();
}
Tensor
result
=
net
->
single_transformer_blocks
.
at
(
idx
)
->
forward
(
Tensor
result
=
net
->
single_transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
temb
),
from_torch
(
rotary_emb_single
));
from_torch
(
hidden_states
),
from_torch
(
temb
),
from_torch
(
rotary_emb_single
));
if
(
net
->
isOffloadEnabled
())
{
net
->
single_transformer_blocks
.
at
(
idx
)
->
releaseLazyParams
();
}
hidden_states
=
to_torch
(
result
);
hidden_states
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
...
...
src/FluxModel.cpp
View file @
ca4d4b46
...
@@ -919,6 +919,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
...
@@ -919,6 +919,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
Tensor
controlnet_block_samples
,
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
)
{
Tensor
controlnet_single_block_samples
)
{
if
(
offload
&&
layer
>
0
)
{
if
(
layer
<
transformer_blocks
.
size
())
{
transformer_blocks
.
at
(
layer
)
->
loadLazyParams
();
}
else
{
transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
())
->
loadLazyParams
();
}
}
if
(
layer
<
transformer_blocks
.
size
())
{
if
(
layer
<
transformer_blocks
.
size
())
{
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
)
->
forward
(
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
)
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
...
@@ -954,6 +962,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
...
@@ -954,6 +962,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
}
if
(
offload
&&
layer
>
0
)
{
if
(
layer
<
transformer_blocks
.
size
())
{
transformer_blocks
.
at
(
layer
)
->
releaseLazyParams
();
}
else
{
transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
())
->
releaseLazyParams
();
}
}
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
}
...
...
src/FluxModel.h
View file @
ca4d4b46
...
@@ -189,6 +189,9 @@ public:
...
@@ -189,6 +189,9 @@ public:
std
::
vector
<
std
::
unique_ptr
<
FluxSingleTransformerBlock
>>
single_transformer_blocks
;
std
::
vector
<
std
::
unique_ptr
<
FluxSingleTransformerBlock
>>
single_transformer_blocks
;
std
::
function
<
Tensor
(
const
Tensor
&
)
>
residual_callback
;
std
::
function
<
Tensor
(
const
Tensor
&
)
>
residual_callback
;
// Reports whether lazy CPU offload of transformer-block parameters is
// active for this model (backed by the private `offload` flag), so callers
// know to load/release lazy params around per-block forward passes.
bool isOffloadEnabled() const { return offload; }
private:
private:
bool
offload
;
bool
offload
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment