Commit 0b1891cd authored by muyangli, committed by Zhekai Zhang

[feat] add first block cache

parent 39f90121
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev", offload=True)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save("flux.1-dev-int4.png")
__version__ = "0.1.4" __version__ = "0.1.5"
import importlib

from diffusers import DiffusionPipeline


def apply_cache_on_transformer(transformer, *args, **kwargs):
    transformer_cls_name = transformer.__class__.__name__
    if False:
        pass
    elif transformer_cls_name.startswith("Flux"):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
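The dispatch in apply_cache_on_pipe is purely name-based. A tiny self-contained sketch of the same behavior (stub classes and the helper name are hypothetical, not part of nunchaku):

class FluxPipelineStub:
    pass


class SomeOtherPipelineStub:
    pass


def pick_adapter(pipe) -> str:
    # Mirrors the name-based dispatch above: any pipeline whose class name starts
    # with "Flux" maps to the flux adapter; everything else is rejected.
    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        return "flux"
    raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")


print(pick_adapter(FluxPipelineStub()))  # -> "flux"
# pick_adapter(SomeOtherPipelineStub())  # would raise ValueError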
@@ -4,14 +4,10 @@ import unittest
import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer
@@ -29,38 +25,20 @@ def apply_cache_on_transformer(
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with (
            unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
            unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
        ):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__
...
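A minimal standalone illustration of the patching trick used by new_forward above: unittest.mock.patch.object swaps an attribute only for the duration of the with block and then restores it (the Model class and its names here are hypothetical, not nunchaku code):

import unittest.mock


class Model:
    def __init__(self):
        self.blocks = ["block0", "block1", "block2"]

    def forward(self):
        return list(self.blocks)


m = Model()
cached_blocks = ["cached_block"]

with unittest.mock.patch.object(m, "blocks", cached_blocks):
    print(m.forward())  # ['cached_block'] -- the cached wrapper is visible inside the call

print(m.forward())      # ['block0', 'block1', 'block2'] -- original blocks restored afterwards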
# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
import contextlib
import dataclasses
@@ -6,6 +6,7 @@ from collections import defaultdict
from typing import DefaultDict, Dict

import torch
from torch import nn


@dataclasses.dataclass
@@ -34,7 +35,6 @@ class CacheContext:
        self.buffers.clear()


@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
@@ -49,7 +49,6 @@ def set_buffer(name, buffer):
    cache_context.set_buffer(name, buffer)


_current_cache_context = None
@@ -79,8 +78,11 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    diff = mean_diff / mean_t1
    return diff.item() < threshold
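The diff = mean_diff / mean_t1 reduction above is the entire cache criterion. A self-contained sketch of one plausible reading (the exact mean_diff and mean_t1 definitions are assumptions; they are not shown in this hunk):

import torch


def are_two_tensors_similar(t1: torch.Tensor, t2: torch.Tensor, *, threshold: float) -> bool:
    # Assumed reduction: mean absolute difference, normalized by the mean magnitude of t1.
    mean_t1 = t1.abs().mean()
    mean_diff = (t1 - t2).abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


a = torch.randn(4, 64)
b = a + 0.01 * torch.randn_like(a)  # small perturbation -> relative diff around 0.01
print(are_two_tensors_similar(a, b, threshold=0.12))                    # True
print(are_two_tensors_similar(a, torch.randn_like(a), threshold=0.12))  # False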
@torch.compiler.disable
def apply_prev_hidden_states_residual(
    hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states
@@ -94,6 +96,7 @@ def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
@@ -105,7 +108,8 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    )
    return can_use_cache
class CachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
@@ -113,6 +117,7 @@ class CachedTransformerBlocks(torch.nn.Module):
        residual_diff_threshold,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
@@ -121,6 +126,7 @@ class CachedTransformerBlocks(torch.nn.Module):
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.verbose = verbose
    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        batch_size = hidden_states.shape[0]
@@ -130,7 +136,8 @@ class CachedTransformerBlocks(torch.nn.Module):
            first_transformer_block = self.transformer_blocks[0]
            encoder_hidden_states, hidden_states = first_transformer_block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
            )
            return (
                hidden_states
@@ -145,7 +152,8 @@ class CachedTransformerBlocks(torch.nn.Module):
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
            0, hidden_states, encoder_hidden_states, *args, **kwargs
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states
@@ -159,12 +167,14 @@ class CachedTransformerBlocks(torch.nn.Module):
        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            if self.verbose:
                print("Cache hit!!!")
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            if self.verbose:
                print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
@@ -192,9 +202,12 @@ class CachedTransformerBlocks(torch.nn.Module):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        encoder_hidden_states, hidden_states = first_transformer_block.forward(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            skip_first_layer=True,
            *args,
            **kwargs,
        )
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
...
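Putting the pieces together, the forward pass above amounts to: run only the first transformer block, compare its residual with the one stored at the previous denoising step, and reuse the cached residual of the remaining blocks when the two are similar enough. A toy, self-contained restatement of that control flow (all names hypothetical; a module-level dict stands in for CacheContext):

import torch

_cache: dict[str, torch.Tensor] = {}


def cached_blocks_forward(x: torch.Tensor, blocks, threshold: float = 0.12) -> torch.Tensor:
    # Always run the first block and measure its residual.
    first_residual = blocks[0](x) - x
    prev = _cache.get("first_residual")
    can_use_cache = (
        prev is not None
        and ((first_residual - prev).abs().mean() / prev.abs().mean()).item() < threshold
    )
    x = x + first_residual
    if can_use_cache:
        # Cache hit: skip the remaining blocks and reuse their residual from the previous step.
        return x + _cache["rest_residual"]
    # Cache miss: run the remaining blocks and store their combined residual for next time.
    _cache["first_residual"] = first_residual
    x_before_rest = x
    for block in blocks[1:]:
        x = block(x)
    _cache["rest_residual"] = x - x_before_rest
    return x


blocks = [lambda t: t + 0.1, lambda t: t * 1.01]
x = torch.randn(2, 4)
y1 = cached_blocks_forward(x, blocks)  # first call: cache miss, full computation
y2 = cached_blocks_forward(x, blocks)  # same input: residuals match, cache hit
print(torch.allclose(y1, y2))          # True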
@@ -59,7 +59,13 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)

        hidden_states = self.m.forward(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt,
            rotary_emb_single,
            skip_first_layer,
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -103,7 +109,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)

        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
        encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
...
@@ -9,11 +9,13 @@ if __name__ == "__main__":
    precision = "fp4" if sm == "120" else "int4"
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=True, precision=precision
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
    )
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
    image.save("flux.1-schnell.png")
param (
[string]$PYTHON_VERSION,
[string]$TORCH_VERSION,
[string]$CUDA_VERSION,
[string]$MAX_JOBS = ""
)
# Conda environment name
$ENV_NAME = "build_env_$PYTHON_VERSION"
# Create the Conda environment
conda create -y -n $ENV_NAME python=$PYTHON_VERSION
conda activate $ENV_NAME
# Install dependencies
conda install -y ninja setuptools wheel pip
pip install --no-cache-dir torch==$TORCH_VERSION numpy --index-url "https://download.pytorch.org/whl/cu$($CUDA_VERSION.Substring(0,2))/"
# Set environment variables
$env:NUNCHAKU_INSTALL_MODE="ALL"
$env:NUNCHAKU_BUILD_WHEELS="1"
$env:MAX_JOBS=$MAX_JOBS
# Change to the directory of this script and build the wheels
Set-Location -Path $PSScriptRoot
if (Test-Path "build") { Remove-Item -Recurse -Force "build" }
python -m build --wheel --no-isolation
# Deactivate the Conda environment
conda deactivate
Write-Output "Build complete!"
@@ -485,16 +485,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    } else {
        raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
        checkCUDA(cudaMemcpy2DAsync(
            raw_attn_output_split.data_ptr(),
            num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
            raw_attn_output.data_ptr(),
            (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
            num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
            batch_size,
            cudaMemcpyDeviceToDevice,
            stream));
    }

    spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
    debug("img.raw_attn_output_split", raw_attn_output_split);
@@ -550,16 +550,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    } else {
        raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
        checkCUDA(cudaMemcpy2DAsync(
            raw_attn_output_split.data_ptr(),
            num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
            raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
            (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
            num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
            batch_size,
            cudaMemcpyDeviceToDevice,
            stream));
    }

    spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
    debug("context.raw_attn_output_split", raw_attn_output_split);
@@ -585,7 +585,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
...
import pytest
from .test_flux_dev import run_test_flux_dev
@pytest.mark.parametrize(
"height,width,num_inference_steps,cache_threshold,lora_name,use_qencoder,cpu_offload,expected_lpips",
[
# (1024, 1024, 50, 0, None, False, False, 0.5), # 13min20s 5min55s 0.19539418816566467
# (1024, 1024, 50, 0.05, None, False, True, 0.5), # 7min11s 0.21917256712913513
# (1024, 1024, 50, 0.12, None, False, True, 0.5), # 2min58s, 0.24101486802101135
# (1024, 1024, 50, 0.2, None, False, True, 0.5), # 2min23s, 0.3101634383201599
# (1024, 1024, 50, 0.5, None, False, True, 0.5), # 1min44s 0.6543852090835571
# (1024, 1024, 30, 0, None, False, False, 0.5), # 8min2s 3min40s 0.2141970843076706
# (1024, 1024, 30, 0.05, None, False, True, 0.5), # 4min57 0.21297718584537506
# (1024, 1024, 30, 0.12, None, False, True, 0.5), # 2min34 0.25963714718818665
# (1024, 1024, 30, 0.2, None, False, True, 0.5), # 1min51 0.31409069895744324
# (1024, 1024, 20, 0, None, False, False, 0.5), # 5min25 2min29 0.18987375497817993
# (1024, 1024, 20, 0.05, None, False, True, 0.5), # 3min3 0.17194810509681702
# (1024, 1024, 20, 0.12, None, False, True, 0.5), # 2min15 0.19407868385314941
# (1024, 1024, 20, 0.2, None, False, True, 0.5), # 1min48 0.2832985818386078
(1024, 1024, 30, 0.12, None, False, False, 0.26),
(512, 2048, 30, 0.12, "anime", True, False, 0.4),
],
)
def test_flux_dev_base(
height: int,
width: int,
num_inference_steps: int,
cache_threshold: float,
lora_name: str | None,
use_qencoder: bool,
cpu_offload: bool,
expected_lpips: float,
):
run_test_flux_dev(
precision="int4",
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=use_qencoder,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=1,
cache_threshold=cache_threshold,
max_dataset_size=16,
expected_lpips=expected_lpips,
)
@@ -6,111 +6,13 @@ import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from safetensors.torch import save_file

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers

from .utils import run_pipeline
from ..data import get_dataset
from ..utils import already_generate, compute_lpips
LORA_PATH_MAP = {
    "hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
@@ -133,6 +35,7 @@ def run_test_flux_dev(
    cpu_offload: bool,
    lora_name: str | None,
    lora_scale: float,
    cache_threshold: float,
    max_dataset_size: int,
    expected_lpips: float,
):
@@ -140,7 +43,6 @@ def run_test_flux_dev(
        "results",
        "dev",
        f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
        + (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
    )
    dataset = get_dataset(
@@ -177,7 +79,12 @@ def run_test_flux_dev(
        # release the gpu memory
        torch.cuda.empty_cache()

    name = precision
    name += "-qencoder" if use_qencoder else ""
    name += "-offload" if cpu_offload else ""
    name += f"-cache{cache_threshold:.2f}" if cache_threshold > 0 else ""
    save_dir_4bit = os.path.join(save_root, name)

    if not already_generate(save_dir_4bit, max_dataset_size):
        pipeline_init_kwargs = {}
        if precision == "int4":
@@ -221,6 +128,9 @@ def run_test_flux_dev(
        pipeline = pipeline.to("cuda")
        if cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        if cache_threshold > 0:
            apply_cache_on_pipe(pipeline, residual_diff_threshold=cache_threshold)
        run_pipeline(
            dataset,
            pipeline,
@@ -252,6 +162,7 @@ def test_flux_dev_base(cpu_offload: bool):
        cpu_offload=cpu_offload,
        lora_name=None,
        lora_scale=0,
        cache_threshold=0,
        max_dataset_size=8,
        expected_lpips=0.16,
    )
@@ -268,85 +179,7 @@ def test_flux_dev_qencoder_800x600():
        cpu_offload=False,
        lora_name=None,
        lora_scale=0,
        cache_threshold=0,
        max_dataset_size=8,
        expected_lpips=0.36,
    )
import pytest
from tests.flux.test_flux_dev import run_test_flux_dev
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, False, 0.16),
(25, "ghibsky", 1, False, 0.16),
(28, "anime", 1, False, 0.27),
(24, "sketch", 1, False, 0.35),
(28, "yarn", 1, False, 0.22),
(25, "haunted_linework", 1, False, 0.34),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=lora_scale,
cache_threshold=0,
max_dataset_size=8,
expected_lpips=expected_lpips,
)
def test_flux_dev_hypersd8_1080x1920():
run_test_flux_dev(
precision="int4",
height=1080,
width=1920,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_name="hypersd8",
lora_scale=0.125,
cache_threshold=0,
max_dataset_size=8,
expected_lpips=0.44,
)
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
(False, False, 17),
(False, True, 13),
(True, False, 12),
(True, True, 6),
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.reset_peak_memory_stats()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
)
memory = torch.cuda.max_memory_reserved(0) / 1024**3
assert memory < memory_limit
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
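The memory assertion above relies on CUDA's peak-statistics counters. The measurement pattern in isolation looks like the following sketch (requires a CUDA device; the workload and numbers are illustrative only):

import torch

torch.cuda.reset_peak_memory_stats()

x = torch.randn(4096, 4096, device="cuda")
y = x @ x  # some GPU work that allocates memory

peak_gib = torch.cuda.max_memory_reserved(0) / 1024**3
print(f"peak reserved memory: {peak_gib:.2f} GiB")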
import os
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from tests.data import get_dataset
from tests.flux.utils import run_pipeline
from tests.utils import already_generate, compute_lpips
@pytest.mark.parametrize(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
[
("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
("int4", 600, 800, 4, 0, False, False, 16, 0.29),
],
)
def test_flux_schnell(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
max_dataset_size: int,
expected_lpips: float,
):
dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(
save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
import os
import torch
from diffusers import FluxPipeline
from tqdm import tqdm
from ..utils import hash_str_to_int
def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
seed = hash_str_to_int(filename)
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
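run_pipeline derives a deterministic per-image seed from the filename via hash_str_to_int, so repeated runs regenerate identical samples. One plausible way such a helper could be written is sketched below; the actual implementation in tests/utils.py may differ:

import hashlib


def hash_str_to_int(s: str) -> int:
    # Deterministic string-to-seed mapping (an assumption, not the repo's exact helper).
    return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) % (2**31)


print(hash_str_to_int("example_prompt_0"))  # same value on every run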