Unverified Commit f0c83919 authored by Boynn's avatar Boynn Committed by GitHub
Browse files

feat: add support for teacache in flux kontext (#618)

* feat: add support for teacache in flux kontext

* merge main and make linter happy
parent 882aa077
import time
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
)
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
start_time = time.time()
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"):
image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")
image.save(f"flux-kontext-dev-{get_precision()}-tc.png")
......@@ -37,8 +37,9 @@ Example:
output = model(inputs_for_step)
Note:
The rescaling function uses polynomial coefficients optimized for Flux models:
[4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
The rescaling function uses polynomial coefficients optimized for Flux and Flux-Kontext models:
Flux: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
Flux-Kontext: [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02]
"""
from types import MethodType
......@@ -58,7 +59,15 @@ from ..models.transformers import NunchakuFluxTransformer2dModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0) -> Callable:
coefficients_by_model = {
"flux": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
"flux-kontext": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02],
}
def make_teacache_forward(
num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0, model_name: str = "flux"
) -> Callable:
"""
Create a cached forward method for Flux transformers using TeaCache.
......@@ -180,21 +189,19 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [
4.98651651e02,
-2.83781631e02,
5.58554382e01,
-3.82021401e00,
2.64230861e-01,
]
coefficients = coefficients_by_model[model_name]
if coefficients is None:
raise ValueError(f"No coefficients found for model {model_name}")
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
self.accumulated_rel_l1_distance += np.abs(
rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
......@@ -315,6 +322,7 @@ class TeaCache:
Useful for model stabilization. Defaults to 0.
enabled (bool, optional): Whether caching is enabled. If False, the model
behaves normally. Defaults to True.
model_name (str, optional): Name of the model to use for caching. Defaults to "flux". It supports "flux" and "flux-kontext".
Attributes:
model: Reference to the transformer model
......@@ -323,6 +331,7 @@ class TeaCache:
skip_steps (int): Number of steps to skip caching
enabled (bool): Caching enabled flag
previous_model_forward: Original forward method (for restoration)
model_name: Name of the model to use for caching. Defaults to "flux". It supports "flux" and "flux-kontext".
Example:
Basic usage::
......@@ -349,6 +358,7 @@ class TeaCache:
rel_l1_thresh: float = 0.6,
skip_steps: int = 0,
enabled: bool = True,
model_name: str = "flux",
) -> None:
self.model = model
self.num_steps = num_steps
......@@ -356,6 +366,7 @@ class TeaCache:
self.skip_steps = skip_steps
self.enabled = enabled
self.previous_model_forward = self.model.forward
self.model_name = model_name
def __enter__(self) -> "TeaCache":
"""
......@@ -374,7 +385,7 @@ class TeaCache:
if self.enabled:
# self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps) # type: ignore
self.model.forward = MethodType(
make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps), self.model
make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps, self.model_name), self.model
)
self.model.cnt = 0
self.model.accumulated_rel_l1_distance = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment