Unverified Commit ca4d4b46 authored by SMG, committed by GitHub

fix: segmentation fault when using FBcache with offload=True (#440)

* fix: cache issue if offload is set to True

* fix: lint
parent 7f71d3ac
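With offload=True, NunchakuFluxTransformer2dModel keeps each transformer block's weights in host memory and only uploads them via loadLazyParams() right before the block runs. The first-block-cache (FB cache) path invoked the single transformer blocks directly, skipping that load; accessing weights that were never uploaded is the likely source of the reported segmentation fault. This commit brackets every block invocation with loadLazyParams()/releaseLazyParams(), guarded by a new isOffloadEnabled() accessor. The example script below reproduces the failing configuration: double FB cache applied to a pipeline whose transformer was loaded with offload=True.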
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

# Detect whether this GPU should use the int4 or fp4 quantized weights.
precision = get_precision()

# offload=True keeps transformer blocks in host memory and streams them to
# the GPU on demand; this is the setting that previously crashed when
# combined with first-block caching.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
    offload=True,
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Enable double first-block caching with separate residual-difference
# thresholds for the joint (multi) and single transformer blocks.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
@@ -143,9 +143,17 @@ public:
        temb = temb.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        // With offloading enabled, this block's weights live in host memory;
        // upload them to the GPU before running the block.
        if (net->isOffloadEnabled()) {
            net->single_transformer_blocks.at(idx)->loadLazyParams();
        }
        Tensor result = net->single_transformer_blocks.at(idx)->forward(
            from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
        // Release the weights again so GPU memory usage stays bounded.
        if (net->isOffloadEnabled()) {
            net->single_transformer_blocks.at(idx)->releaseLazyParams();
        }
        hidden_states = to_torch(result);
        Tensor::synchronizeDevice();
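The pattern above repeats the same isOffloadEnabled() check before and after forward(). A minimal sketch of an alternative, assuming only the loadLazyParams()/releaseLazyParams() interface shown in this diff (LazyParamGuard is a hypothetical name, not part of nunchaku): an RAII guard would also release the weights if forward() throws.

// Sketch only, not part of this commit. Assumes a block type exposing
// loadLazyParams()/releaseLazyParams() as in the diff above.
template <typename Block>
class LazyParamGuard {
public:
    LazyParamGuard(Block &block, bool enabled) : block_(block), enabled_(enabled) {
        if (enabled_)
            block_.loadLazyParams(); // upload this block's weights to the GPU
    }
    ~LazyParamGuard() {
        if (enabled_)
            block_.releaseLazyParams(); // free them once the forward pass is done
    }
    LazyParamGuard(const LazyParamGuard &) = delete;
    LazyParamGuard &operator=(const LazyParamGuard &) = delete;

private:
    Block &block_;
    bool enabled_;
};

// The call site above would then shrink to:
//   LazyParamGuard guard(*net->single_transformer_blocks.at(idx), net->isOffloadEnabled());
//   Tensor result = net->single_transformer_blocks.at(idx)->forward(...);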
@@ -919,6 +919,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                              Tensor controlnet_block_samples,
                                              Tensor controlnet_single_block_samples) {
    if (offload && layer > 0) {
        // Load the parameters of whichever block this flat layer index maps to:
        // the joint blocks come first, then the single blocks.
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    if (layer < transformer_blocks.size()) {
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
@@ -954,6 +962,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    if (offload && layer > 0) {
        // Mirror of the load above: drop this block's parameters from the GPU.
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

    return {hidden_states, encoder_hidden_states};
}
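For reference, a standalone sketch of the layer-to-block mapping that both the load and release branches rely on. The block counts are assumptions matching FLUX.1-dev, not values read from this diff:

#include <cstdio>

// Layers [0, numJoint) address transformer_blocks; layers beyond that address
// single_transformer_blocks at offset (layer - numJoint).
int main() {
    const unsigned numJoint = 19;  // joint (double-stream) blocks, illustrative
    const unsigned numSingle = 38; // single-stream blocks, illustrative
    for (unsigned layer = 0; layer < numJoint + numSingle; ++layer) {
        if (layer < numJoint)
            std::printf("layer %2u -> transformer_blocks[%u]\n", layer, layer);
        else
            std::printf("layer %2u -> single_transformer_blocks[%u]\n", layer, layer - numJoint);
    }
    return 0;
}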
@@ -189,6 +189,9 @@ public:
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
    std::function<Tensor(const Tensor &)> residual_callback;

    // Expose the private offload flag so callers outside FluxModel (such as
    // the caching adapter above) can pair loadLazyParams()/releaseLazyParams()
    // around block invocations.
    bool isOffloadEnabled() const {
        return offload;
    }

private:
    bool offload;