Unverified commit 37a27712 authored by Muyang Li, committed by GitHub

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
@@ -9,13 +9,13 @@ import GPUtil
 import spaces
 import torch
 from peft.tuners import lora
-from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
+from nunchaku.models.safety_checker import SafetyChecker

 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip

 def get_args() -> argparse.Namespace:
@@ -84,7 +84,7 @@ def generate(
     images, latency_strs = [], []
     for i, pipeline in enumerate(pipelines):
         precision = args.precisions[i]
-        progress = gr.Progress(track_tqdm=True)
+        gr.Progress(track_tqdm=True)
         if pipeline.cur_lora_name != lora_name:
             if precision == "bf16":
                 for m in pipeline.transformer.modules():
@@ -164,7 +164,7 @@ if len(gpus) > 0:
     device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."

-notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

 with gr.Blocks(
     css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
...
 import torch
 from diffusers import FluxPipeline
 from peft.tuners import lora
-from vars import LORA_PATHS, SVDQ_LORA_PATHS
 from nunchaku import NunchakuFluxTransformer2dModel
+from vars import LORA_PATHS, SVDQ_LORA_PATHS

 def hash_str_to_int(s: str) -> int:
...
 <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
   <div>
     <h1>
-      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
+      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
           alt="logo"
           style="height: 40px; width: auto; display: block; margin: auto;"/>
       <a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
...
@@ -2,7 +2,6 @@ import argparse
 import os

 import torch
 from utils import get_pipeline
...
@@ -4,7 +4,6 @@ import time
 import torch
 from torch import nn
 from tqdm import trange
 from utils import get_pipeline
...
@@ -8,13 +8,13 @@ from datetime import datetime
 import GPUtil
 import spaces
 import torch
-from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import EXAMPLES, MAX_SEED
+from nunchaku.models.safety_checker import SafetyChecker

 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip

 def get_args() -> argparse.Namespace:
@@ -73,7 +73,7 @@ def generate(
         prompt = "A peaceful world."
     images, latency_strs = [], []
     for i, pipeline in enumerate(pipelines):
-        progress = gr.Progress(track_tqdm=True)
+        gr.Progress(track_tqdm=True)
         start_time = time.time()
         image = pipeline(
             prompt=prompt,
@@ -124,11 +124,11 @@ if len(gpus) > 0:
     device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."

-notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

 with gr.Blocks(
     css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
-    title=f"SVDQuant SANA-1600M Demo",
+    title="SVDQuant SANA-1600M Demo",
 ) as demo:

     def get_header_str():
...
@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
 from diffusers.utils import load_image

 from nunchaku import NunchakuFluxTransformer2dModel
-from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
 from nunchaku.utils import get_gpu_memory, get_precision

 base_model = "black-forest-labs/FLUX.1-dev"
@@ -29,11 +28,6 @@ if need_offload:
 else:
     pipeline = pipeline.to("cuda")

-# apply_cache_on_pipe(
-#     pipeline, residual_diff_threshold=0.1
-# ) # Uncomment this line to enable first-block cache to speedup generation
 prompt = "An anime style girl with messy beach waves."
 control_image_depth = load_image(
     "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
...
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
from types import MethodType

import torch
from diffusers.utils import load_image

from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision

precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Route the transformer through the PuLID-aware forward pass
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")

image = pipeline(
    "A woman holding a sign that says 'SVDQuant is fast!'",
    id_image=id_image,
    id_weight=1,
    num_inference_steps=12,
    guidance_scale=3.5,
).images[0]
image.save("flux.1-dev-pulid.png")
import time

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detects whether the precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

start_time = time.time()
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True):
    image = pipeline(
        "A cat holding a sign that says hello world",
        num_inference_steps=50,
        guidance_scale=3.5,
        height=1024,
        width=1024,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images[0]
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")
image.save(f"flux.1-dev-{precision}-tc.png")
 from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel

+__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]

-__version__ = "0.3.0dev0"
+__version__ = "0.3.0dev1"
@@ -7,16 +7,30 @@ from torch import nn
 from ...caching import utils

-def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    use_double_fb_cache: bool = False,
+    residual_diff_threshold: float = 0.12,
+    residual_diff_threshold_multi: float | None = None,
+    residual_diff_threshold_single: float = 0.1,
+):
+    if residual_diff_threshold_multi is None:
+        residual_diff_threshold_multi = residual_diff_threshold
+
     if getattr(transformer, "_is_cached", False):
-        transformer.cached_transformer_blocks[0].update_threshold(residual_diff_threshold)
+        transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
+            use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
+        )
         return transformer

     cached_transformer_blocks = nn.ModuleList(
         [
             utils.FluxCachedTransformerBlocks(
                 transformer=transformer,
-                residual_diff_threshold=residual_diff_threshold,
+                use_double_fb_cache=use_double_fb_cache,
+                residual_diff_threshold_multi=residual_diff_threshold_multi,
+                residual_diff_threshold_single=residual_diff_threshold_single,
                 return_hidden_states_first=False,
             )
         ]
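
For reference, this transformer-level entry point can also be called directly. Per the signature above, omitting residual_diff_threshold_multi makes it fall back to residual_diff_threshold, and a second call on an already-cached transformer takes the _is_cached branch and only updates the thresholds. A minimal sketch (the transformer object is assumed to come from one of the examples above):

from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_transformer

# First call: wraps the blocks in FluxCachedTransformerBlocks.
apply_cache_on_transformer(
    transformer,
    use_double_fb_cache=True,
    residual_diff_threshold=0.12,  # reused for the multi blocks since _multi is None
)

# Second call: hits the `_is_cached` branch and updates the thresholds in place.
apply_cache_on_transformer(
    transformer,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)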
...
from types import MethodType
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import logging
from diffusers.utils.constants import USE_PEFT_BACKEND
from diffusers.utils.import_utils import is_torch_version
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers

from ..models.transformers import NunchakuFluxTransformer2dModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0) -> Callable:
    def teacache_forward(
        self: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_projections: torch.Tensor,
        timestep: torch.LongTensor,
        img_ids: torch.Tensor,
        txt_ids: torch.Tensor,
        guidance: torch.Tensor,
        joint_attention_kwargs: Optional[dict[str, Any]] = None,
        controlnet_block_samples: Optional[torch.Tensor] = None,
        controlnet_single_block_samples: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000  # type: ignore
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)  # type: ignore
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # TeaCache indicator: the modulated input of the first transformer block's norm layer.
        inp = hidden_states.clone()
        temb_ = temb.clone()
        modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)  # type: ignore

        if self.cnt == 0 or self.cnt == num_steps - 1:
            # Always compute the first and last steps.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Rescale the relative L1 change of the modulated input with a fitted polynomial
            # and accumulate it; the transformer blocks are skipped while it stays below the threshold.
            coefficients = [
                4.98651651e02,
                -2.83781631e02,
                5.58554382e01,
                -3.82021401e00,
                2.64230861e-01,
            ]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                (
                    (modulated_inp - self.previous_modulated_input).abs().mean()
                    / self.previous_modulated_input.abs().mean()
                )
                .cpu()
                .item()
            )
            if self.accumulated_rel_l1_distance < rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == num_steps:
            self.cnt = 0

        ckpt_kwargs: dict[str, Any]
        if self.cnt > skip_steps:
            if not should_calc:
                # Reuse the cached residual instead of running the transformer blocks.
                hidden_states += self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                for index_block, block in enumerate(self.transformer_blocks):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module, return_dict=None):  # type: ignore
                            def custom_forward(*inputs):  # type: ignore
                                if return_dict is not None:
                                    return module(*inputs, return_dict=return_dict)
                                else:
                                    return module(*inputs)

                            return custom_forward

                        ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            hidden_states,
                            encoder_hidden_states,
                            temb,
                            image_rotary_emb,
                            **ckpt_kwargs,
                        )
                    else:
                        encoder_hidden_states, hidden_states = block(
                            hidden_states=hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            temb=temb,
                            image_rotary_emb=image_rotary_emb,
                            joint_attention_kwargs=joint_attention_kwargs,
                        )

                    # controlnet residual
                    if controlnet_block_samples is not None:
                        interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                        interval_control = int(np.ceil(interval_control))
                        # For Xlabs ControlNet.
                        if controlnet_blocks_repeat:
                            hidden_states = (
                                hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                            )
                        else:
                            hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

                for index_block, block in enumerate(self.single_transformer_blocks):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module, return_dict=None):  # type: ignore
                            def custom_forward(*inputs):  # type: ignore
                                if return_dict is not None:
                                    return module(*inputs, return_dict=return_dict)
                                else:
                                    return module(*inputs)

                            return custom_forward

                        ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            hidden_states,
                            temb,
                            image_rotary_emb,
                            **ckpt_kwargs,
                        )
                    else:
                        hidden_states = block(
                            hidden_states=hidden_states,
                            temb=temb,
                            image_rotary_emb=image_rotary_emb,
                            joint_attention_kwargs=joint_attention_kwargs,
                        )

                    # controlnet residual
                    if controlnet_single_block_samples is not None:
                        interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                        interval_control = int(np.ceil(interval_control))
                        hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                            hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                            + controlnet_single_block_samples[index_block // interval_control]
                        )

                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                # Cache the residual of the full block stack for reuse on skipped steps.
                self.previous_residual = hidden_states - ori_hidden_states
        else:
            # Warm-up steps: always run the full block stack without touching the cache.
            for index_block, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):  # type: ignore
                        def custom_forward(*inputs):  # type: ignore
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_block_samples is not None:
                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    # For Xlabs ControlNet.
                    if controlnet_blocks_repeat:
                        hidden_states = (
                            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                        )
                    else:
                        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):  # type: ignore
                        def custom_forward(*inputs):  # type: ignore
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_single_block_samples is not None:
                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                        + controlnet_single_block_samples[index_block // interval_control]
                    )

            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output: torch.FloatTensor = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return output

        return Transformer2DModelOutput(sample=output)

    return teacache_forward


# A context manager that adds TeaCache support to a block of code. While it is
# active, the model passed in is patched to use the TeaCache forward pass; on
# exit, the original forward method is restored.
class TeaCache:
    def __init__(
        self,
        model: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
        num_steps: int = 50,
        rel_l1_thresh: float = 0.6,
        skip_steps: int = 0,
        enabled: bool = True,
    ) -> None:
        self.model = model
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.skip_steps = skip_steps
        self.enabled = enabled
        self.previous_model_forward = self.model.forward

    def __enter__(self) -> "TeaCache":
        if self.enabled:
            # self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps)  # type: ignore
            self.model.forward = MethodType(
                make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps), self.model
            )
            self.model.cnt = 0
            self.model.accumulated_rel_l1_distance = 0
            self.model.previous_modulated_input = None
            self.model.previous_residual = None
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if self.enabled:
            self.model.forward = self.previous_model_forward
            del self.model.cnt
            del self.model.accumulated_rel_l1_distance
            del self.model.previous_modulated_input
            del self.model.previous_residual
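
If the patch needs to outlive a with block, the context manager's body can be repackaged as plain functions. This sketch only restates __enter__/__exit__ above and is not repository API:

def enable_teacache(transformer, num_steps=50, rel_l1_thresh=0.3, skip_steps=0):
    # Mirrors TeaCache.__enter__: patch the instance-level forward and reset cache state.
    original_forward = transformer.forward
    transformer.forward = MethodType(
        make_teacache_forward(num_steps, rel_l1_thresh, skip_steps), transformer
    )
    transformer.cnt = 0
    transformer.accumulated_rel_l1_distance = 0
    transformer.previous_modulated_input = None
    transformer.previous_residual = None
    return original_forward  # keep this around to undo the patch


def disable_teacache(transformer, original_forward):
    # Mirrors TeaCache.__exit__: restore the original forward and drop cache state.
    transformer.forward = original_forward
    del transformer.cnt, transformer.accumulated_rel_l1_distance
    del transformer.previous_modulated_input, transformer.previous_residual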