Unverified Commit f0c83919 authored by Boynn's avatar Boynn Committed by GitHub
Browse files

feat: add support for teacache in flux kontext (#618)

* feat: add support for teacache in flux kontext

* merge main and make linter happy
parent 882aa077
import time

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

# Load the SVDQuant Flux.1-Kontext transformer whose checkpoint matches the
# precision reported by get_precision() (e.g. int4 / fp4 depending on the GPU).
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
)

# Build the Kontext pipeline around the quantized transformer and move it to GPU.
pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Reference image that the prompt will edit.
source_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
).convert("RGB")

prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"

# Run the edit with TeaCache step-skipping enabled and time the full call.
t_start = time.time()
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"):
    result = pipeline(image=source_image, prompt=prompt, guidance_scale=2.5).images[0]
t_end = time.time()

print(f"Time taken: {(t_end - t_start)} seconds")
result.save(f"flux-kontext-dev-{get_precision()}-tc.png")
@@ -37,8 +37,9 @@ Example:
        output = model(inputs_for_step)

Note:
    The rescaling function uses polynomial coefficients optimized for Flux and Flux-Kontext models:
    Flux: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
    Flux-Kontext: [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02]
"""
from types import MethodType
@@ -58,7 +59,15 @@ from ..models.transformers import NunchakuFluxTransformer2dModel
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

coefficients_by_model = {
    "flux": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    "flux-kontext": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02],
}


def make_teacache_forward(
    num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0, model_name: str = "flux"
) -> Callable:
    """
    Create a cached forward method for Flux transformers using TeaCache.
@@ -180,21 +189,19 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = coefficients_by_model[model_name]
            if coefficients is None:
                raise ValueError(f"No coefficients found for model {model_name}")
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += np.abs(
                rescale_func(
                    (
                        (modulated_inp - self.previous_modulated_input).abs().mean()
                        / self.previous_modulated_input.abs().mean()
                    )
                    .cpu()
                    .item()
                )
            )
            if self.accumulated_rel_l1_distance < rel_l1_thresh:
                should_calc = False
@@ -315,6 +322,7 @@ class TeaCache:
            Useful for model stabilization. Defaults to 0.
        enabled (bool, optional): Whether caching is enabled. If False, the model
            behaves normally. Defaults to True.
        model_name (str, optional): Name of the model to use for caching. Defaults to "flux". It supports "flux" and "flux-kontext".

    Attributes:
        model: Reference to the transformer model
@@ -323,6 +331,7 @@ class TeaCache:
        skip_steps (int): Number of steps to skip caching
        enabled (bool): Caching enabled flag
        previous_model_forward: Original forward method (for restoration)
        model_name: Name of the model to use for caching. Defaults to "flux". It supports "flux" and "flux-kontext".

    Example:
        Basic usage::
@@ -349,6 +358,7 @@ class TeaCache:
        rel_l1_thresh: float = 0.6,
        skip_steps: int = 0,
        enabled: bool = True,
        model_name: str = "flux",
    ) -> None:
        self.model = model
        self.num_steps = num_steps
@@ -356,6 +366,7 @@ class TeaCache:
        self.skip_steps = skip_steps
        self.enabled = enabled
        self.previous_model_forward = self.model.forward
        self.model_name = model_name
    def __enter__(self) -> "TeaCache":
        """
@@ -374,7 +385,7 @@
        if self.enabled:
            # self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps)  # type: ignore
            self.model.forward = MethodType(
                make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps, self.model_name), self.model
            )
            self.model.cnt = 0
            self.model.accumulated_rel_l1_distance = 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment