Commit 4eec372d authored by gushiqiao, committed by GitHub

Support Hunyuan offload and TeaCache. (#15)



* Support Hunyuan offload and TeaCache.

* Fix

* Fix

---------
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent 8b76a6a5
import argparse
from contextlib import contextmanager
import torch
import torch.distributed as dist
import os
......@@ -9,6 +8,7 @@ import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from contextlib import contextmanager
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
......@@ -16,15 +16,14 @@ from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.text2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerFeatureCaching
from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching
from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.text2v.models.networks.wan.model import WanModel
from lightx2v.text2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
......@@ -34,8 +33,10 @@ from lightx2v.image2v.models.wan.model import CLIPModel
@contextmanager
def time_duration(label: str = ""):
torch.cuda.synchronize()
start_time = time.time()
yield
torch.cuda.synchronize()
end_time = time.time()
print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
......@@ -61,16 +62,16 @@ def load_models(args, model_config):
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
elif args.model_cls == "wan2.1":
text_encoder = T5EncoderModel(
text_len=model_config["text_len"],
dtype=torch.bfloat16,
device=init_device,
checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
with time_duration("Load Text Encoder"):
text_encoder = T5EncoderModel(
text_len=model_config["text_len"],
dtype=torch.bfloat16,
device=init_device,
checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
with time_duration("Load Wan Model"):
model = WanModel(args.model_path, model_config, init_device)
......@@ -256,8 +257,10 @@ def init_scheduler(args, image_encoder_output):
if args.model_cls == "hunyuan":
if args.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(args, image_encoder_output)
elif args.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(args, image_encoder_output)
elif args.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerFeatureCaching(args, image_encoder_output)
scheduler = HunyuanSchedulerTaylorCaching(args, image_encoder_output)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
......@@ -265,7 +268,7 @@ def init_scheduler(args, image_encoder_output):
if args.feature_caching == "NoCaching":
scheduler = WanScheduler(args)
elif args.feature_caching == "Tea":
scheduler = WanSchedulerFeatureCaching(args)
scheduler = WanSchedulerTeaCaching(args)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
......@@ -338,7 +341,6 @@ if __name__ == "__main__":
parser.add_argument("--use_bfloat16", action="store_true", default=True)
parser.add_argument("--lora_path", type=str, default=None)
parser.add_argument("--strength_model", type=float, default=1.0)
args = parser.parse_args()
start_time = time.time()
......@@ -383,7 +385,8 @@ if __name__ == "__main__":
else:
image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
with time_duration("Run Text Encoder"):
text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
set_target_shape(args, image_encoder_output)
scheduler = init_scheduler(args, image_encoder_output)
......@@ -399,16 +402,15 @@ if __name__ == "__main__":
del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
torch.cuda.empty_cache()
images = run_vae(latents, generator, args)
with time_duration("Run VAE"):
images = run_vae(latents, generator, args)
if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
save_video_st = time.time()
if args.model_cls == "wan2.1":
cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, args.save_video_path, fps=24)
save_video_et = time.time()
print(f"Save video cost: {save_video_et - save_video_st}")
with time_duration("Save video"):
if args.model_cls == "wan2.1":
cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, args.save_video_path, fps=24)
end_time = time.time()
print(f"Total cost: {end_time - start_time}")
......@@ -36,11 +36,15 @@ class MMWeightTemplate(metaclass=ABCMeta):
def to_cpu(self, non_blocking=False):
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cuda(non_blocking=non_blocking)
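to_cpu/to_cuda now also move weight_scale for quantized weights, and both accept non_blocking. Asynchronous copies only actually overlap with compute when the host tensor is pinned and the copy is issued on a separate stream; a minimal, self-contained sketch of that pattern (illustrative only — the pin_memory() call is an assumption, nothing in this commit pins weights):

import torch

weight = torch.randn(4096, 4096, dtype=torch.bfloat16).pin_memory()  # pinned host buffer
copy_stream = torch.cuda.Stream()

with torch.cuda.stream(copy_stream):
    gpu_weight = weight.cuda(non_blocking=True)       # async H2D copy on a side stream
# ... compute on the default stream can proceed here ...
torch.cuda.current_stream().wait_stream(copy_stream)  # weights guaranteed resident before use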
......@@ -109,7 +113,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def load_int8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", True):
self.weight = weight_dict[self.weight_name].to(torch.float32)
self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
w_quantizer = IntegerQuantizer(8, True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
......@@ -245,7 +249,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias.float(), input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
......@@ -268,7 +272,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias.float(), input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
return output_tensor.squeeze(0)
......
import torch
import numpy as np
from einops import rearrange
from lightx2v.attentions import attention
from .utils import taylor_cache_init, derivative_approximation, taylor_formula
from ..utils_bf16 import apply_rotary_emb
from typing import Dict
import math
from ..transformer_infer import HunyuanTransformerInfer
def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize Taylor cache, expanding storage areas for Taylor series derivatives
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current["step"] == 0:
        cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}

def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute derivative approximation
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
    # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature
    for i in range(cache_dic["max_order"]):
        if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
        else:
            break
    cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors

def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Compute Taylor expansion error
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current["step"] - current["activated_steps"][-1]
    # x = current['t'] - current['activated_times'][-1]
    output = 0
    for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
        output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
    return output

class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)

    def infer(
        self,
        weights,
        img,
        txt,
        vec,
        cu_seqlens_qkv,
        max_seqlen_qkv,
        freqs_cis,
        token_replace_vec=None,
        frist_frame_token_num=None,
    ):
        inp = img.clone()
        vec_ = vec.clone()
        weights.double_blocks_weights[0].to_cuda()
        img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks_weights[0].img_mod.apply(vec_).chunk(6, dim=-1)
        weights.double_blocks_weights[0].to_cpu_sync()
        normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
        modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        del normed_inp, inp, vec_

        if self.scheduler.cnt == 0 or self.scheduler.cnt == self.scheduler.num_steps - 1:
            should_calc = True
            self.scheduler.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.scheduler.coefficients)
            self.scheduler.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.scheduler.previous_modulated_input).abs().mean() / self.scheduler.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.scheduler.accumulated_rel_l1_distance < self.scheduler.teacache_thresh:
                should_calc = False
            else:
                should_calc = True
                self.scheduler.accumulated_rel_l1_distance = 0
        self.scheduler.previous_modulated_input = modulated_inp
        del modulated_inp

        if not should_calc:
            img += self.scheduler.previous_residual
        else:
            ori_img = img.clone()
            img, vec = super().infer(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            self.scheduler.previous_residual = img - ori_img
            del ori_img
            torch.cuda.empty_cache()
        return img, vec
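The TeaCaching path above loads only the first double block's modulation weights, computes the modulated input cheaply, and accumulates a polynomial-rescaled relative-L1 change against the previous step; while the accumulator stays under teacache_thresh it adds the cached previous_residual instead of running every transformer block. A standalone sketch of that decision rule (illustrative, not part of the commit; the threshold value is assumed, the real one comes from args.teacache_thresh):

import numpy as np

coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
rescale = np.poly1d(coefficients)
teacache_thresh = 0.15  # assumed value
accumulated = 0.0

def should_recompute(rel_l1, step, num_steps):
    """rel_l1 = |mod_inp - prev_mod_inp|.mean() / |prev_mod_inp|.mean() for this step."""
    global accumulated
    if step == 0 or step == num_steps - 1:  # first and last steps always recompute
        accumulated = 0.0
        return True
    accumulated += rescale(rel_l1)
    if accumulated < teacache_thresh:
        return False  # reuse previous_residual
    accumulated = 0.0
    return True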
class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
def __init__(self, config):
super().__init__(config)
assert not self.config["cpu_offload"], "cpu_offload is not supported for TaylorCaching"
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
self.scheduler.current["stream"] = "double_stream"
for i in range(self.double_blocks_num):
self.scheduler.current["layer"] = i
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = torch.cat((img, txt), 0)
self.scheduler.current["stream"] = "single_stream"
for i in range(self.single_blocks_num):
self.scheduler.current["layer"] = i
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
return img, vec
......@@ -133,8 +142,24 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
)
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
img = self.infer_double_block_img_post_atten(
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
)
txt = self.infer_double_block_txt_post_atten(
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
return img, txt
elif self.scheduler.current["type"] == "taylor_cache":
......@@ -166,7 +191,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
return img, txt
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
def infer_double_block_img_post_atten(
self,
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
):
self.scheduler.current["module"] = "img_attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
......@@ -190,7 +224,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
img = img + out
return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
def infer_double_block_txt_post_atten(
self,
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
self.scheduler.current["module"] = "txt_attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
......
from typing import Dict
import math
import torch
def taylor_cache_init(cache_dic: Dict, current: Dict):
"""
Initialize Taylor cache, expanding storage areas for Taylor series derivatives
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
if current["step"] == 0:
cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
"""
Compute derivative approximation
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
# difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
updated_taylor_factors = {}
updated_taylor_factors[0] = feature
for i in range(cache_dic["max_order"]):
if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
else:
break
cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
"""
Compute the Taylor expansion approximation of the cached feature
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
x = current["step"] - current["activated_steps"][-1]
# x = current['t'] - current['activated_times'][-1]
output = 0
for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
return output
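These three helpers implement the Taylor feature cache: on a "full" step derivative_approximation stores the feature and its finite-difference derivatives (up to max_order), and on a "taylor_cache" step taylor_formula extrapolates the feature as f + f'·x, where x is the number of steps since the last full step. A simplified, self-contained version of the same arithmetic for a single module (illustrative only; the real code nests the cache by stream, layer, and module):

import math
import torch

cache = {}  # order -> tensor

def update(activated_steps, feature):
    # Call on a 'full' step: keep the feature and a first-order finite difference.
    new = {0: feature}
    if 0 in cache and len(activated_steps) > 1:
        dx = activated_steps[-1] - activated_steps[-2]
        new[1] = (feature - cache[0]) / dx
    cache.clear()
    cache.update(new)

def predict(step, activated_steps):
    # Call on a skipped step: Taylor-extrapolate from the last full step.
    x = step - activated_steps[-1]
    return sum(cache[i] * (x ** i) / math.factorial(i) for i in cache)

update([0], torch.tensor([1.0, 2.0]))     # full step 0
update([0, 5], torch.tensor([2.0, 4.0]))  # full step 5: first-order term = (feature - cached) / 5
print(predict(7, [0, 5]))                 # tensor([2.4000, 4.8000])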
......@@ -38,15 +38,7 @@ class HunyuanTransformerInfer:
self.double_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
img, txt = self.infer_double_block(
self.double_weights_stream_mgr.active_weights[0],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if double_block_idx < self.double_blocks_num - 1:
self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
......@@ -54,23 +46,21 @@ class HunyuanTransformerInfer:
x = torch.cat((img, txt), 0)
img = img.cpu()
txt = txt.cpu()
del img, txt
torch.cuda.empty_cache()
for single_block_idx in range(self.single_blocks_num):
if single_block_idx == 0:
self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
self.single_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
x = self.infer_single_block(
weights.single_blocks_weights[single_block_idx],
x,
vec,
txt_seq_len,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if single_block_idx < self.single_blocks_num - 1:
self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
self.single_weights_stream_mgr.swap_weights()
torch.cuda.empty_cache()
img = x[:img_seq_len, ...]
return img, vec
......
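The refactored offload loop above keeps only the active block's weights on the GPU and asks the stream manager to prefetch the next block while the current one computes. The WeightStreamManager itself is not shown in this commit; the sketch below only illustrates the same double-buffering idea, with hypothetical block objects exposing to_cuda/to_cpu/forward:

import torch

copy_stream = torch.cuda.Stream()

def run_blocks(blocks, x):
    blocks[0].to_cuda()                                    # first block loaded up front
    for i, block in enumerate(blocks):
        if i + 1 < len(blocks):
            with torch.cuda.stream(copy_stream):           # prefetch the next block's weights
                blocks[i + 1].to_cuda(non_blocking=True)   # overlaps with the compute below
        x = block.forward(x)                               # stands in for infer_double_block(...)
        torch.cuda.current_stream().wait_stream(copy_stream)  # prefetched weights ready for next iteration
        block.to_cpu(non_blocking=True)                    # evict the finished block
    return x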
......@@ -6,7 +6,7 @@ from lightx2v.text2v.models.networks.hunyuan.weights.transformer_weights import
from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
......@@ -43,7 +43,9 @@ class HunyuanModel:
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
......@@ -107,3 +109,8 @@ class HunyuanModel:
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt == self.scheduler.num_steps:
self.scheduler.cnt = 0
......@@ -3,7 +3,7 @@ from ..transformer_infer import WanTransformerInfer
import torch
class WanTransformerInferFeatureCaching(WanTransformerInfer):
class WanTransformerInferTeaCaching(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
......
......@@ -12,7 +12,7 @@ from lightx2v.text2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
......@@ -49,7 +49,7 @@ class WanModel:
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferFeatureCaching
self.transformer_infer_class = WanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
......
import torch
from .utils import cache_init, cal_type
from ..scheduler import HunyuanScheduler
import torch
def cache_init(num_steps, model_kwargs=None):
"""
Initialization for cache.
"""
cache_dic = {}
cache = {}
cache_index = {}
cache[-1] = {}
cache_index[-1] = {}
cache_index["layer_index"] = {}
cache_dic["attn_map"] = {}
cache_dic["attn_map"][-1] = {}
cache_dic["attn_map"][-1]["double_stream"] = {}
cache_dic["attn_map"][-1]["single_stream"] = {}
cache_dic["k-norm"] = {}
cache_dic["k-norm"][-1] = {}
cache_dic["k-norm"][-1]["double_stream"] = {}
cache_dic["k-norm"][-1]["single_stream"] = {}
cache_dic["v-norm"] = {}
cache_dic["v-norm"][-1] = {}
cache_dic["v-norm"][-1]["double_stream"] = {}
cache_dic["v-norm"][-1]["single_stream"] = {}
cache_dic["cross_attn_map"] = {}
cache_dic["cross_attn_map"][-1] = {}
cache[-1]["double_stream"] = {}
cache[-1]["single_stream"] = {}
cache_dic["cache_counter"] = 0
for j in range(20):
cache[-1]["double_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}
for j in range(40):
cache[-1]["single_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
cache_dic["k-norm"][-1]["single_stream"][j] = {}
cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["v-norm"][-1]["single_stream"][j] = {}
cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["taylor_cache"] = False
cache_dic["duca"] = False
cache_dic["test_FLOPs"] = False
mode = "Taylor"
if mode == "original":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
elif mode == "ToCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = False
elif mode == "DuCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = True
elif mode == "Taylor":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 5
cache_dic["max_order"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["taylor_cache"] = True
cache_dic["first_enhance"] = 1
current = {}
current["num_steps"] = num_steps
current["activated_steps"] = [0]
return cache_dic, current
def force_scheduler(cache_dic, current):
if cache_dic["fresh_ratio"] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# no force constrain for sensitive steps, cause the performance is good enough.
# you may have a try.
cache_dic["cal_threshold"] = threshold
# return threshold
def cal_type(cache_dic, current):
"""
Determine calculation type for this step
"""
if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform
first_step = current["step"] == 0
else:
# ToCa: First enhanced
first_step = current["step"] < cache_dic["first_enhance"]
# first_step = (current['step'] <= 3)
force_fresh = cache_dic["force_fresh"]
if not first_step:
fresh_interval = cache_dic["cal_threshold"]
else:
fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current["type"] = "full"
cache_dic["cache_counter"] = 0
current["activated_steps"].append(current["step"])
# current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif cache_dic["taylor_cache"]:
cache_dic["cache_counter"] += 1
current["type"] = "taylor_cache"
else:
cache_dic["cache_counter"] += 1
if cache_dic["duca"]:
if cache_dic["cache_counter"] % 2 == 1: # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA'
else:
current["type"] = "aggressive"
else:
current["type"] = "ToCa"
# if current['step'] < 25:
# current['type'] = 'FORA'
# else:
# current['type'] = 'aggressive'
######################################################################
# if (current['step'] in [3,2,1,0]):
# current['type'] = 'full'

class HunyuanSchedulerTeaCaching(HunyuanScheduler):
    def __init__(self, args, image_encoder_output):
        super().__init__(args, image_encoder_output)
        self.cnt = 0
        self.num_steps = self.args.infer_steps
        self.teacache_thresh = self.args.teacache_thresh
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

    def clear(self):
        if self.previous_residual is not None:
            self.previous_residual = self.previous_residual.cpu()
        if self.previous_modulated_input is not None:
            self.previous_modulated_input = self.previous_modulated_input.cpu()
        self.previous_modulated_input = None
        self.previous_residual = None
        torch.cuda.empty_cache()
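previous_modulated_input and previous_residual are full-sized activation tensors that stay on the GPU across denoising steps, so clear() pushes them to CPU, drops the references, and empties the CUDA cache. Where clear() is invoked is not visible in this commit; a hypothetical end-of-run cleanup would be:

scheduler.clear()         # releases previous_residual / previous_modulated_input
torch.cuda.empty_cache()  # return cached blocks to the driver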
class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
def __init__(self, args):
super().__init__(args)
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
def __init__(self, args, image_encoder_output):
super().__init__(args, image_encoder_output)
self.cache_dic, self.current = cache_init(self.infer_steps)
def step_pre(self, step_index):
......
import torch
def cache_init(num_steps, model_kwargs=None):
"""
Initialization for cache.
"""
cache_dic = {}
cache = {}
cache_index = {}
cache[-1] = {}
cache_index[-1] = {}
cache_index["layer_index"] = {}
cache_dic["attn_map"] = {}
cache_dic["attn_map"][-1] = {}
cache_dic["attn_map"][-1]["double_stream"] = {}
cache_dic["attn_map"][-1]["single_stream"] = {}
cache_dic["k-norm"] = {}
cache_dic["k-norm"][-1] = {}
cache_dic["k-norm"][-1]["double_stream"] = {}
cache_dic["k-norm"][-1]["single_stream"] = {}
cache_dic["v-norm"] = {}
cache_dic["v-norm"][-1] = {}
cache_dic["v-norm"][-1]["double_stream"] = {}
cache_dic["v-norm"][-1]["single_stream"] = {}
cache_dic["cross_attn_map"] = {}
cache_dic["cross_attn_map"][-1] = {}
cache[-1]["double_stream"] = {}
cache[-1]["single_stream"] = {}
cache_dic["cache_counter"] = 0
for j in range(20):
cache[-1]["double_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}
for j in range(40):
cache[-1]["single_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
cache_dic["k-norm"][-1]["single_stream"][j] = {}
cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["v-norm"][-1]["single_stream"][j] = {}
cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["taylor_cache"] = False
cache_dic["duca"] = False
cache_dic["test_FLOPs"] = False
mode = "Taylor"
if mode == "original":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
elif mode == "ToCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = False
elif mode == "DuCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = True
elif mode == "Taylor":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 5
cache_dic["max_order"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["taylor_cache"] = True
cache_dic["first_enhance"] = 1
current = {}
current["num_steps"] = num_steps
current["activated_steps"] = [0]
return cache_dic, current
def force_scheduler(cache_dic, current):
if cache_dic["fresh_ratio"] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# no forced constraint on sensitive steps, because the performance is already good enough.
# you may want to give it a try.
cache_dic["cal_threshold"] = threshold
# return threshold
def cal_type(cache_dic, current):
"""
Determine calculation type for this step
"""
if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform
first_step = current["step"] == 0
else:
# ToCa: First enhanced
first_step = current["step"] < cache_dic["first_enhance"]
# first_step = (current['step'] <= 3)
force_fresh = cache_dic["force_fresh"]
if not first_step:
fresh_interval = cache_dic["cal_threshold"]
else:
fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current["type"] = "full"
cache_dic["cache_counter"] = 0
current["activated_steps"].append(current["step"])
# current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif cache_dic["taylor_cache"]:
cache_dic["cache_counter"] += 1
current["type"] = "taylor_cache"
else:
cache_dic["cache_counter"] += 1
if cache_dic["duca"]:
if cache_dic["cache_counter"] % 2 == 1: # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA'
else:
current["type"] = "aggressive"
else:
current["type"] = "ToCa"
# if current['step'] < 25:
# current['type'] = 'FORA'
# else:
# current['type'] = 'aggressive'
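With the "Taylor" settings selected in cache_init (fresh_ratio=0.0, taylor_cache=True, fresh_threshold=5, first_enhance=1, max_order=1), cal_type marks step 0 as "full", then every fifth step afterwards, and tags everything in between "taylor_cache". A standalone simulation of that counter logic (illustrative, not part of the commit):

fresh_threshold = 5  # recompute interval
first_enhance = 1    # leading steps forced to 'full'
counter = 0
schedule = []

for step in range(12):
    first_step = step < first_enhance
    if first_step or counter == fresh_threshold - 1:
        schedule.append((step, "full"))
        counter = 0
    else:
        schedule.append((step, "taylor_cache"))
        counter += 1

print(schedule)  # 'full' at steps 0, 5 and 10; 'taylor_cache' elsewhere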
......@@ -10,3 +10,6 @@ class BaseScheduler:
def step_pre(self, step_index):
self.step_index = step_index
self.latents = self.latents.to(dtype=torch.bfloat16)
def clear(self):
pass
......@@ -2,7 +2,7 @@ import torch
from ..scheduler import WanScheduler
class WanSchedulerFeatureCaching(WanScheduler):
class WanSchedulerTeaCaching(WanScheduler):
def __init__(self, args):
super().__init__(args)
self.cnt = 0
......
......@@ -341,6 +341,3 @@ class WanScheduler(BaseScheduler):
self.lower_order_nums += 1
self.latents = prev_sample
def clear(self):
pass