Commit b3f12860 authored by Bluear7878

[Auto Sync] feat: double FB cache + adaptive mechanisms (#76)

* DoubleFBCache

* rename: DoubleFBCache -> use_double_fb_cache
parent e4f8ae9b
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-dev"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
image = pipeline(
["A cat holding a sign that says hello world"],
num_inference_steps=50
).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
@@ -7,16 +7,28 @@ from torch import nn
from ...caching import utils
def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
def apply_cache_on_transformer(
transformer: FluxTransformer2DModel,
*,
use_double_fb_cache: bool = False,
residual_diff_threshold: float = 0.12,
residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float = 0.1,
):
if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold
if getattr(transformer, "_is_cached", False):
transformer.cached_transformer_blocks[0].update_threshold(residual_diff_threshold)
transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
    use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
)
return transformer
cached_transformer_blocks = nn.ModuleList(
[
utils.FluxCachedTransformerBlocks(
transformer=transformer,
residual_diff_threshold=residual_diff_threshold,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
return_hidden_states_first=False,
)
]
@@ -3,11 +3,15 @@
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional
from typing import DefaultDict, Dict, Optional, Tuple
import torch
from torch import nn
from nunchaku.models.transformers.utils import pad_tensor
num_transformer_blocks = 19 # FIXME
num_single_transformer_blocks = 38 # FIXME
@dataclasses.dataclass
class CacheContext:
@@ -75,38 +79,127 @@ def cache_context(cache_context):
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = mean_diff / mean_t1
return diff.item() < threshold
diff = (mean_diff / mean_t1).item()
return diff < threshold, diff
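# Illustrative helper (not part of the commit): a standalone version of the
# relative metric above, mean|t1 - t2| / mean|t1|, which is what the cache-hit
# decision compares against residual_diff_threshold_multi/_single.
def _toy_relative_l1_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
    # mean absolute difference normalized by the mean absolute value of t1
    return ((t1 - t2).abs().mean() / t1.abs().mean()).item()
# e.g. _toy_relative_l1_diff(prev_residual, curr_residual) < 0.09 -> reuse the cache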
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states_residual = get_buffer("hidden_states_residual")
assert hidden_states_residual is not None, "hidden_states_residual must be set before"
hidden_states = hidden_states_residual + hidden_states
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is not None:
encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, (
"multi_hidden_states_residual must be set before"
)
hidden_states = hidden_states + hidden_states_residual
hidden_states = hidden_states.contiguous()
return hidden_states, encoder_hidden_states
if encoder_hidden_states is not None:
enc_hidden_res = get_buffer("multi_encoder_hidden_states_residual")
msg = "multi_encoder_hidden_states_residual must be set before"
assert enc_hidden_res is not None, msg
encoder_hidden_states = encoder_hidden_states + enc_hidden_res
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
elif mode == "single":
single_residual = get_buffer("single_hidden_states_residual")
msg = "single_hidden_states_residual must be set before"
assert single_residual is not None, msg
hidden_states = hidden_states + single_residual
hidden_states = hidden_states.contiguous()
return hidden_states
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
prev_first_hidden_states_residual,
def get_can_use_cache(first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"):
if mode == "multi":
buffer_name = "first_multi_hidden_states_residual"
elif mode == "single":
buffer_name = "first_single_hidden_states_residual"
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
prev_res = get_buffer(buffer_name)
if prev_res is None:
return False, threshold
is_similar, diff = are_two_tensors_similar(
prev_res,
first_hidden_states_residual,
threshold=threshold,
parallelized=parallelized,
)
return can_use_cache
return is_similar, diff
def check_and_apply_cache(
*,
first_residual: torch.Tensor,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
threshold: float,
parallelized: bool,
mode: str,
verbose: bool,
call_remaining_fn,
remaining_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
can_use_cache, diff = get_can_use_cache(
first_residual,
threshold=threshold,
parallelized=parallelized,
mode=mode,
)
torch._dynamo.graph_break()
if can_use_cache:
if verbose:
print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, "
f"new threshold={threshold:.4f}")
out = apply_prev_hidden_states_residual(
hidden_states, encoder_hidden_states, mode=mode
)
updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
return updated_h, updated_enc, threshold
old_threshold = threshold
if verbose:
print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, "
f"was={old_threshold:.4f} => now={threshold:.4f}")
if mode == "multi":
set_buffer("first_multi_hidden_states_residual", first_residual)
else:
set_buffer("first_single_hidden_states_residual", first_residual)
result = call_remaining_fn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
**remaining_kwargs
)
if mode == "multi":
updated_h, updated_enc, hs_res, enc_res = result
set_buffer("multi_hidden_states_residual", hs_res)
set_buffer("multi_encoder_hidden_states_residual", enc_res)
return updated_h, updated_enc, threshold
elif mode == "single":
updated_cat_states, cat_res = result
set_buffer("single_hidden_states_residual", cat_res)
return updated_cat_states, None, threshold
raise ValueError(f"Unknown mode {mode}")
class SanaCachedTransformerBlocks(nn.Module):
@@ -230,109 +323,339 @@ class FluxCachedTransformerBlocks(nn.Module):
def __init__(
self,
*,
transformer=None,
residual_diff_threshold,
return_hidden_states_first=True,
return_hidden_states_only=False,
transformer: nn.Module = None,
use_double_fb_cache: bool = True,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
return_hidden_states_first: bool = True,
return_hidden_states_only: bool = False,
verbose: bool = False,
):
super().__init__()
self.transformer = transformer
self.transformer_blocks = transformer.transformer_blocks
self.single_transformer_blocks = transformer.single_transformer_blocks
self.residual_diff_threshold = residual_diff_threshold
self.use_double_fb_cache = use_double_fb_cache
self.residual_diff_threshold_multi = residual_diff_threshold_multi
self.residual_diff_threshold_single = residual_diff_threshold_single
self.return_hidden_states_first = return_hidden_states_first
self.return_hidden_states_only = return_hidden_states_only
self.verbose = verbose
def update_residual_diff_threshold(self, residual_diff_threshold=0.12):
self.residual_diff_threshold = residual_diff_threshold
self.m = self.transformer_blocks[0].m
self.dtype = torch.bfloat16 if self.m.isBF16() else torch.float16
self.device = transformer.device
@staticmethod
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
M = rotemb.shape[1]
D = rotemb.shape[2] * 2
msg_shape = "rotemb shape must be (B, M, D//2, 1, 2)"
assert rotemb.shape == (B, M, D // 2, 1, 2), msg_shape
assert M % 16 == 0
assert D % 8 == 0
rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
rotemb = rotemb.permute(0, 1, 3, 2, 4)
# 16*8 pack, FP32 accumulator (C) format
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
rotemb = rotemb.contiguous()
rotemb = rotemb.view(B, M, D)
return rotemb
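# Shape-contract sketch for pack_rotemb (illustrative values, not from the
# commit): the input must be float32 of shape (B, M, D // 2, 1, 2) with M a
# multiple of 16 and D a multiple of 8; the output is the same data repacked
# into a (B, M, D) tensor laid out for the 16x8 MMA FP32-accumulator format.
#
#   rot = torch.randn(1, 512, 64, 1, 2, dtype=torch.float32)   # B=1, M=512, D=128
#   packed = FluxCachedTransformerBlocks.pack_rotemb(rot)       # -> shape (1, 512, 128)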
def update_residual_diff_threshold(
self,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.12,
residual_diff_threshold_single=0.09
):
self.use_double_fb_cache = use_double_fb_cache
self.residual_diff_threshold_multi = residual_diff_threshold_multi
self.residual_diff_threshold_single = residual_diff_threshold_single
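# Usage note (illustrative, not from the commit): calling apply_cache_on_pipe
# again on an already-cached transformer goes through the _is_cached branch of
# the adapter and only retunes the thresholds via this method, e.g.
#
#   apply_cache_on_pipe(
#       pipeline,
#       use_double_fb_cache=True,
#       residual_diff_threshold_multi=0.06,   # values here are only examples
#       residual_diff_threshold_single=0.10,
#   )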
def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 1:
if batch_size > 1:
print("Batch size > 1 currently not supported")
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
first_transformer_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = first_transformer_block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
)
original_dtype = hidden_states.dtype
original_device = hidden_states.device
hidden_states = hidden_states.to(self.dtype).to(self.device)
encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
temb = temb.to(self.dtype).to(self.device)
image_rotary_emb = image_rotary_emb.to(self.device)
if controlnet_block_samples is not None:
return (
hidden_states
if self.return_hidden_states_only
else (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)
controlnet_block_samples = (
torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
)
if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
controlnet_single_block_samples = (
torch.stack(controlnet_single_block_samples).to(self.device)
if len(controlnet_single_block_samples) > 0
else None
)
original_hidden_states = hidden_states
first_transformer_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
0, hidden_states, encoder_hidden_states, *args, **kwargs
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
# [1, tokens, head_dim/2, 1, 2] (sincos)
total_tokens = txt_tokens + img_tokens
assert image_rotary_emb.shape[2] == 1 * total_tokens
image_rotary_emb = image_rotary_emb.reshape(
[1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
)
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]
rotary_emb_single = image_rotary_emb
first_hidden_states_residual = hidden_states - original_hidden_states
del original_hidden_states
rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=self.residual_diff_threshold,
parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
)
if (self.residual_diff_threshold_multi < 0.0) or (batch_size > 1):
if batch_size > 1 and self.verbose:
print("Batch size > 1 currently not supported")
torch._dynamo.graph_break()
if can_use_cache:
del first_hidden_states_residual
if self.verbose:
print("Cache hit!!!")
hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
hidden_states, encoder_hidden_states
)
else:
if self.verbose:
print("Cache miss!!!")
set_buffer("first_hidden_states_residual", first_hidden_states_residual)
del first_hidden_states_residual
(
hidden_states = self.m.forward(
hidden_states,
encoder_hidden_states,
hidden_states_residual,
encoder_hidden_states_residual,
) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
set_buffer("hidden_states_residual", hidden_states_residual)
set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
skip_first_layer,
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
hidden_states = hidden_states[:, txt_tokens:, ...]
if self.return_hidden_states_only:
return hidden_states
if self.return_hidden_states_first:
return hidden_states, encoder_hidden_states
return encoder_hidden_states, hidden_states
remaining_kwargs = {
"temb": temb,
"rotary_emb_img": rotary_emb_img,
"rotary_emb_txt": rotary_emb_txt,
"rotary_emb_single": rotary_emb_single,
"controlnet_block_samples": controlnet_block_samples,
"controlnet_single_block_samples": controlnet_single_block_samples,
"txt_tokens": txt_tokens,
}
original_hidden_states = hidden_states
first_hidden_states, first_encoder_hidden_states = self.m.forward_layer(
0,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
controlnet_block_samples,
controlnet_single_block_samples,
)
hidden_states = first_hidden_states
encoder_hidden_states = first_encoder_hidden_states
first_hidden_states_residual_multi = hidden_states - original_hidden_states
del original_hidden_states
if self.use_double_fb_cache:
call_remaining_fn = self.call_remaining_multi_transformer_blocks
else:
call_remaining_fn = self.call_remaining_FBCache_transformer_blocks
torch._dynamo.graph_break()
updated_h, updated_enc, threshold = check_and_apply_cache(
first_residual=first_hidden_states_residual_multi,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
threshold=self.residual_diff_threshold_multi,
parallelized=(
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
mode="multi",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn,
remaining_kwargs=remaining_kwargs,
)
self.residual_diff_threshold_multi = threshold
if not self.use_double_fb_cache:
if self.return_hidden_states_only:
return updated_h
if self.return_hidden_states_first:
return updated_h, updated_enc
return updated_enc, updated_h
# DoubleFBCache
cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
original_cat = cat_hidden_states
cat_hidden_states = self.m.forward_single_layer(
0, cat_hidden_states, temb, rotary_emb_single
)
return (
hidden_states
if self.return_hidden_states_only
else (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)
first_hidden_states_residual_single = cat_hidden_states - original_cat
del original_cat
call_remaining_fn_single = self.call_remaining_single_transformer_blocks
updated_cat, _, threshold = check_and_apply_cache(
first_residual=first_hidden_states_residual_single,
hidden_states=cat_hidden_states,
encoder_hidden_states=None,
threshold=self.residual_diff_threshold_single,
parallelized=(
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
mode="single",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn_single,
remaining_kwargs=remaining_kwargs,
)
self.residual_diff_threshold_single = threshold
def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
first_transformer_block = self.transformer_blocks[0]
# torch._dynamo.graph_break()
final_enc = updated_cat[:, :txt_tokens, ...]
final_h = updated_cat[:, txt_tokens:, ...]
final_h = final_h.to(original_dtype).to(original_device)
final_enc = final_enc.to(original_dtype).to(original_device)
if self.return_hidden_states_only:
return final_h
if self.return_hidden_states_first:
return final_h, final_enc
return final_enc, final_h
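# Flow summary (descriptive comment added for clarity): block 0 of the 19 joint
# (multi) transformer blocks always runs; its residual is compared against the
# cached one and, on a miss, blocks 1-18 are recomputed (or, when
# use_double_fb_cache is False, the whole remaining network via
# call_remaining_FBCache_transformer_blocks). With use_double_fb_cache, block 0
# of the 38 single blocks then runs on the concatenated txt+img sequence and a
# second threshold gates blocks 1-37 the same way before the sequence is split
# back into encoder and image tokens.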
def call_remaining_FBCache_transformer_blocks(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
rotary_emb_img: torch.Tensor,
rotary_emb_txt: torch.Tensor,
rotary_emb_single: torch.Tensor,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=True,
txt_tokens=None,
):
original_dtype = hidden_states.dtype
original_device = hidden_states.device
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
encoder_hidden_states, hidden_states = first_transformer_block.forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
skip_first_layer=True,
*args,
**kwargs,
hidden_states = self.m.forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
skip_first_layer,
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
hidden_states = hidden_states[:, txt_tokens:, ...]
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
hidden_states_residual = hidden_states - original_hidden_states
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
enc_residual = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
return hidden_states, encoder_hidden_states, hidden_states_residual, enc_residual
def call_remaining_multi_transformer_blocks(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
rotary_emb_img: torch.Tensor,
rotary_emb_txt: torch.Tensor,
rotary_emb_single: torch.Tensor,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
txt_tokens=None,
):
start_idx = 1
original_hidden_states = hidden_states.clone()
original_encoder_hidden_states = encoder_hidden_states.clone()
for idx in range(start_idx, num_transformer_blocks):
hidden_states, encoder_hidden_states = self.m.forward_layer(
idx,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
controlnet_block_samples,
controlnet_single_block_samples,
)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
hs_res = hidden_states - original_hidden_states
enc_res = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hs_res, enc_res
def call_remaining_single_transformer_blocks(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
rotary_emb_img: torch.Tensor,
rotary_emb_txt: torch.Tensor,
rotary_emb_single: torch.Tensor,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
txt_tokens=None,
):
start_idx = 1
original_hidden_states = hidden_states.clone()
for idx in range(start_idx, num_single_transformer_blocks):
hidden_states = self.m.forward_single_layer(
idx,
hidden_states,
temb,
rotary_emb_single,
)
hidden_states = hidden_states.contiguous()
hs_res = hidden_states - original_hidden_states
return hidden_states, hs_res
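# Toy illustration (not from the commit) of the token layout the single-stream
# stage caches over: text and image tokens are concatenated along dim=1,
# treated as one residual, and split back after the blocks run.
def _toy_cat_and_split(enc: torch.Tensor, img: torch.Tensor, txt_tokens: int):
    cat = torch.cat([enc, img], dim=1)           # what forward_single_layer sees
    return cat[:, :txt_tokens, ...], cat[:, txt_tokens:, ...]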
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
],
)
def test_flux_dev_cache(
use_double_fb_cache: bool,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
expected_lpips=expected_lpips,
)
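# Example invocation (illustrative; the exact test file path depends on the
# repo layout):
#   pytest -k test_flux_dev_cache -s
# Both parametrized cases use residual_diff_threshold_multi=0.09 and
# residual_diff_threshold_single=0.12; the LPIPS bound is 0.24 for int4 builds
# and 0.144 otherwise.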
@@ -13,6 +13,7 @@ from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
@@ -141,6 +142,9 @@ def run_test(
attention_impl: str = "flashattn2", # "flashattn2" or "nunchaku-fp16"
cpu_offload: bool = False,
cache_threshold: float = 0,
use_double_fb_cache: bool = False,
residual_diff_threshold_multi: float = 0,
residual_diff_threshold_single: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 4,
@@ -259,6 +263,12 @@ def run_test(
precision_str += "-co"
if cache_threshold > 0:
precision_str += f"-cache{cache_threshold}"
if use_double_fb_cache:
precision_str += "-dfb"
if residual_diff_threshold_multi > 0:
precision_str += f"-rdm{residual_diff_threshold_multi}"
if residual_diff_threshold_single > 0:
precision_str += f"-rds{residual_diff_threshold_single}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
if batch_size > 1:
@@ -303,6 +313,14 @@ def run_test(
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
if use_double_fb_cache:
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
)
run_pipeline(
batch_size=batch_size,
dataset=dataset,