Commit 39f90121 authored by Hyunsung Lee, committed by Zhekai Zhang

Add dynamic caching when batch_size = 1 for the Flux model

parent 804a6d30
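This commit adds a first-block cache for batch_size = 1 Flux inference: every denoising step still runs the first joint transformer block, and when that block's output residual is close to the previous step's, the remaining blocks are skipped and cached residuals are reused. The example below enables it on the INT4 FLUX.1-schnell pipeline with a relative-difference threshold of 0.12. A minimal sketch of the decision rule (the helper name here is illustrative; the actual check is are_two_tensors_similar in the caching utils module further down):

import torch

def should_use_cache(prev_residual: torch.Tensor, curr_residual: torch.Tensor, threshold: float) -> bool:
    # Relative mean absolute difference between the first block's residual at the
    # previous and the current step; below the threshold, cached residuals are reused.
    rel_diff = (prev_residual - curr_residual).abs().mean() / prev_residual.abs().mean()
    return rel_diff.item() < threshold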
import torch
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
import time

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)

image = pipeline(
    ["A cat holding a sign that says hello world"],
    width=1024,
    height=1024,
    num_inference_steps=32,
    guidance_scale=0,
).images[0]
image.save("flux.1-schnell-int4-0.12.png")
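The script imports time but never uses it; a hypothetical timing helper along these lines (the function name is mine, not part of the commit) can be appended to compare a cached run against the same pipeline without apply_cache_on_pipe:

def time_pipeline(pipe, prompt="A cat holding a sign that says hello world", steps=32):
    # One timed run; torch.cuda.synchronize() keeps the wall-clock measurement honest.
    torch.cuda.synchronize()
    start = time.time()
    pipe([prompt], width=1024, height=1024, num_inference_steps=steps, guidance_scale=0)
    torch.cuda.synchronize()
    return time.time() - start

print(f"cached run: {time_pipeline(pipeline):.2f}s")

The nunchaku.caching.diffusers_adapters dispatcher follows.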
import importlib

from diffusers import DiffusionPipeline


def apply_cache_on_transformer(transformer, *args, **kwargs):
    transformer_cls_name = transformer.__class__.__name__
    if False:
        pass
    elif transformer_cls_name.startswith("Flux"):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name = pipe.__class__.__name__
    if False:
        pass
    elif pipe_cls_name.startswith("Flux"):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
    print("Registering Flux")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from nunchaku.caching import utils


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    residual_diff_threshold=0.05,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.CachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    **kwargs,
):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
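apply_cache_on_pipe with shallow_patch=True only wraps the pipeline's __call__ in a cache context and leaves the transformer untouched, which leaves room to call apply_cache_on_transformer separately with its own settings; a sketch (the 0.1 threshold is arbitrary):

apply_cache_on_pipe(pipeline, shallow_patch=True)
apply_cache_on_transformer(pipeline.transformer, residual_diff_threshold=0.1)

The shared caching utilities follow.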
# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict

import torch


@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_name(self):
        self.incremental_name_counters.clear()

    # @torch.compiler.disable  # this decorator is a torch.compile (Dynamo) feature, not TorchScript
    def get_buffer(self, name: str):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()

@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context

@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


@torch.compiler.disable
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states

    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return can_use_cache

class CachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.single_transformer_blocks = transformer.single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        batch_size = hidden_states.shape[0]
        if self.residual_diff_threshold <= 0.0 or batch_size > 1:
            if batch_size > 1:
                print("Batch size > 1 currently not supported")

            first_transformer_block = self.transformer_blocks[0]
            encoder_hidden_states, hidden_states = first_transformer_block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
            )
            return (
                hidden_states
                if self.return_hidden_states_only
                else (
                    (hidden_states, encoder_hidden_states)
                    if self.return_hidden_states_first
                    else (encoder_hidden_states, hidden_states)
                )
            )

        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
            0, hidden_states, encoder_hidden_states, *args, **kwargs
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            print("Cache hit!!!")
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        first_transformer_block = self.transformer_blocks[0]
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        encoder_hidden_states, hidden_states = first_transformer_block.forward(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            skip_first_layer=True,
            *args,
            **kwargs,
        )
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
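A standalone illustration of the cache-context plumbing above (toy tensor shapes; in normal use the patched pipeline __call__ creates the context and CachedTransformerBlocks reads and writes the buffers):

ctx = create_cache_context()
with cache_context(ctx):
    set_buffer("first_hidden_states_residual", torch.zeros(1, 16, 64))
    assert get_buffer("first_hidden_states_residual") is not None
# Outside the `with` block no context is active, so get_buffer()/set_buffer() would assert.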
@@ -18,12 +18,13 @@ public:
    }

    torch::Tensor forward(
        torch::Tensor hidden_states,
        torch::Tensor encoder_hidden_states,
        torch::Tensor temb,
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context,
-       torch::Tensor rotary_emb_single)
+       torch::Tensor rotary_emb_single,
+       bool skip_first_layer = false)
    {
        checkModel();
@@ -42,7 +43,8 @@ public:
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
-           from_torch(rotary_emb_single)
+           from_torch(rotary_emb_single),
+           skip_first_layer
        );
        torch::Tensor output = to_torch(result);
@@ -53,10 +55,10 @@ public:
    std::tuple<torch::Tensor, torch::Tensor> forward_layer(
        int64_t idx,
        torch::Tensor hidden_states,
        torch::Tensor encoder_hidden_states,
        torch::Tensor temb,
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context)
    {
        spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
@@ -85,8 +87,8 @@ public:
    torch::Tensor forward_single_layer(
        int64_t idx,
        torch::Tensor hidden_states,
        torch::Tensor temb,
        torch::Tensor rotary_emb_single)
    {
        spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
@@ -25,12 +25,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
    def forward(
        self,
        /,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        skip_first_layer=False,
    ):
        batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
@@ -59,7 +59,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)

        hidden_states = self.m.forward(
-           hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
+           hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single, skip_first_layer
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -69,6 +69,47 @@ class NunchakuFluxTransformerBlocks(nn.Module):

        return encoder_hidden_states, hidden_states
    def forward_layer_at(
        self,
        idx: int,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
    ):
        batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        hidden_states = hidden_states.to(self.dtype).to(self.device)
        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
        assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
        # [bs, tokens, head_dim / 2, 1, 2] (sincos)
        image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)

        rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
        rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)

        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
        encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)

        return encoder_hidden_states, hidden_states
## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
@@ -204,6 +245,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
        self.update_unquantized_lora_params(strength)

    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
        print("Injecting quantized module")
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])

        ### Compatible with the original forward method
@@ -4,6 +4,7 @@ import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from typing import Optional, Any


class NunchakuModelLoaderMixin:
@@ -65,12 +66,10 @@ class NunchakuModelLoaderMixin:
        return transformer, transformer_block_path


def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


-def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill=0) -> torch.Tensor:
+def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
    if multiples <= 1:
        return tensor
    if tensor is None:
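pad_tensor's body is truncated here; judging from its call sites (pad_tensor(rotary_emb_txt, 256, 1)), it grows the given dimension to the next multiple of multiples. A rough standalone equivalent, written as an assumption rather than a copy of the real implementation:

import torch

def pad_to_multiple(t: torch.Tensor, multiples: int, dim: int, fill=0) -> torch.Tensor:
    # Hypothetical stand-in for pad_tensor(): pad `dim` up to the next multiple of `multiples`.
    if multiples <= 1 or t.shape[dim] % multiples == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = multiples - t.shape[dim] % multiples
    filler = torch.full(pad_shape, fill, dtype=t.dtype, device=t.device)
    return torch.cat([t, filler], dim=dim)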
@@ -410,7 +410,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    auto stream = getCurrentCUDAStream();

    Tensor concat;
    Tensor pool;
    {
        nvtxRangePushA("qkv_proj");
@@ -422,16 +422,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        pool = blockSparse
            ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
            : Tensor{};

        for (int i = 0; i < batch_size; i++) {
            // img first
            Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
            Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_context);
            Tensor pool_qkv = pool.valid()
                ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
                : Tensor{};
            Tensor pool_qkv_context = pool.valid()
                ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE)
                : Tensor{};
@@ -626,7 +626,7 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
    }
}

-Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) {
+Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();
@@ -639,6 +639,7 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0) return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
@@ -129,7 +129,7 @@ private:
class FluxModel : public Module {
public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
-   Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
+   Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer = false);

public:
    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
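Putting the pieces together: on a cache miss, CachedTransformerBlocks runs the first joint block once through forward_layer_at(0, ...) and then re-enters the quantized model with skip_first_layer=True, so FluxModel::forward skips layer 0 and the first block is never evaluated twice in a step. A paraphrased sketch of the call order (argument names follow the Python wrapper above; not literal code from the diff):

# Inside CachedTransformerBlocks.forward, cache-miss path (paraphrased):
encoder_hidden_states, hidden_states = block.forward_layer_at(
    0, hidden_states, encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
)
# ... residual comparison decides against reusing the cache ...
encoder_hidden_states, hidden_states = block.forward(
    hidden_states=hidden_states,
    temb=temb,
    encoder_hidden_states=encoder_hidden_states,
    image_rotary_emb=image_rotary_emb,
    skip_first_layer=True,  # maps to FluxModel::forward(..., skip_first_layer = true) in C++
)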