Unverified Commit 5de6d7cf authored by SMG's avatar SMG Committed by GitHub
Browse files

fix: enable correct batch processing in teacache (#601)

* fix teacache_batch

* lint
parent 3bcc2d43
"""Example: batched FLUX.1-dev inference with Nunchaku's TeaCache enabled.

Generates one image per prompt in a single batched pipeline call and saves
each result as ``flux.1-dev-{precision}{i}-tc.png``.
Requires a CUDA GPU; precision ('int4' vs 'fp4') is auto-detected.
"""

import time

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

# Auto-detect whether your GPU wants the 'int4' or 'fp4' quantized checkpoint.
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

start_time = time.time()
prompts = [
    "A cheerful woman in a pastel dress, holding a basket of colorful Easter eggs with a sign that says 'Happy Easter'",
    "A young peace activist with a gentle smile, holding a handmade sign that says 'Peace'",
    # NOTE(review): the sign text below is missing its closing quote
    # ("'Let's Cook!") — presumably a typo in the prompt; confirm intent
    # before changing, since it alters the generated image.
    "A friendly chef wearing a tall white hat, holding a wooden spoon with a sign that says 'Let's Cook!",
]
# TeaCache skips redundant transformer steps when consecutive-step residuals
# are close (rel_l1_thresh controls the skip aggressiveness).
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True):
    images = pipeline(
        prompts,
        num_inference_steps=50,
        guidance_scale=3.5,
        height=1024,
        width=1024,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")

# Save one file per prompt. Looping (instead of hard-coding three indices)
# keeps the script correct if the prompt list changes length; filenames match
# the original scheme: flux.1-dev-{precision}1-tc.png, ...2..., ...3...
for i, img in enumerate(images, start=1):
    img.save(f"flux.1-dev-{precision}{i}-tc.png")
......@@ -161,16 +161,8 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
......@@ -250,60 +242,10 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
self.previous_residual = hidden_states - ori_hidden_states
else:
for index_block, block in enumerate(self.transformer_blocks):
......@@ -335,61 +277,10 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output: torch.FloatTensor = self.proj_out(hidden_states)
......@@ -398,7 +289,7 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
unscale_lora_layers(self, lora_scale)
if not return_dict:
return output
return (output,)
return Transformer2DModelOutput(sample=output)
......
......@@ -821,7 +821,7 @@ class FluxCachedTransformerBlocks(nn.Module):
-----
If batch size > 2 or residual_diff_threshold <= 0, caching is disabled for now.
"""
batch_size = hidden_states.shape[0]
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -860,9 +860,7 @@ class FluxCachedTransformerBlocks(nn.Module):
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
if (self.residual_diff_threshold_multi < 0.0) or (batch_size > 1):
if batch_size > 1 and self.verbose:
print("Batch size > 1 currently not supported")
if self.residual_diff_threshold_multi < 0.0:
hidden_states = self.m.forward(
hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment