Commit 4eec372d authored by gushiqiao, committed by GitHub

Support hunyuan offload and teacache. (#15)



* Support hunyuan offload and teacache.

* Fix

* Fix

---------
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent 8b76a6a5
import argparse
-from contextlib import contextmanager
import torch
import torch.distributed as dist
import os
@@ -9,6 +8,7 @@ import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
+from contextlib import contextmanager
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
@@ -16,15 +16,14 @@ from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
-from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
+from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.text2v.models.schedulers.wan.scheduler import WanScheduler
-from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerFeatureCaching
+from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching
from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.text2v.models.networks.wan.model import WanModel
from lightx2v.text2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
@@ -34,8 +33,10 @@ from lightx2v.image2v.models.wan.model import CLIPModel
@contextmanager
def time_duration(label: str = ""):
+    torch.cuda.synchronize()
    start_time = time.time()
    yield
+    torch.cuda.synchronize()
    end_time = time.time()
    print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
@@ -61,16 +62,16 @@ def load_models(args, model_config):
        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
    elif args.model_cls == "wan2.1":
-        text_encoder = T5EncoderModel(
-            text_len=model_config["text_len"],
-            dtype=torch.bfloat16,
-            device=init_device,
-            checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
-            tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
-            shard_fn=None,
-        )
-        text_encoders = [text_encoder]
+        with time_duration("Load Text Encoder"):
+            text_encoder = T5EncoderModel(
+                text_len=model_config["text_len"],
+                dtype=torch.bfloat16,
+                device=init_device,
+                checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
+                tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
+                shard_fn=None,
+            )
+            text_encoders = [text_encoder]
        with time_duration("Load Wan Model"):
            model = WanModel(args.model_path, model_config, init_device)
@@ -256,8 +257,10 @@ def init_scheduler(args, image_encoder_output):
    if args.model_cls == "hunyuan":
        if args.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(args, image_encoder_output)
+        elif args.feature_caching == "Tea":
+            scheduler = HunyuanSchedulerTeaCaching(args, image_encoder_output)
        elif args.feature_caching == "TaylorSeer":
-            scheduler = HunyuanSchedulerFeatureCaching(args, image_encoder_output)
+            scheduler = HunyuanSchedulerTaylorCaching(args, image_encoder_output)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
@@ -265,7 +268,7 @@ def init_scheduler(args, image_encoder_output):
        if args.feature_caching == "NoCaching":
            scheduler = WanScheduler(args)
        elif args.feature_caching == "Tea":
-            scheduler = WanSchedulerFeatureCaching(args)
+            scheduler = WanSchedulerTeaCaching(args)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
@@ -338,7 +341,6 @@ if __name__ == "__main__":
    parser.add_argument("--use_bfloat16", action="store_true", default=True)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--strength_model", type=float, default=1.0)
    args = parser.parse_args()

    start_time = time.time()
@@ -383,7 +385,8 @@ if __name__ == "__main__":
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

-    text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
+    with time_duration("Run Text Encoder"):
+        text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
    set_target_shape(args, image_encoder_output)
    scheduler = init_scheduler(args, image_encoder_output)
@@ -399,16 +402,15 @@ if __name__ == "__main__":
    del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
    torch.cuda.empty_cache()

-    images = run_vae(latents, generator, args)
+    with time_duration("Run VAE"):
+        images = run_vae(latents, generator, args)

    if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
-        save_video_st = time.time()
+        with time_duration("Save video"):
            if args.model_cls == "wan2.1":
                cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, args.save_video_path, fps=24)
-        save_video_et = time.time()
-        print(f"Save video cost: {save_video_et - save_video_st}")

    end_time = time.time()
    print(f"Total cost: {end_time - start_time}")
@@ -36,11 +36,15 @@ class MMWeightTemplate(metaclass=ABCMeta):
    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+        if hasattr(self, "weight_scale"):
+            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
+        if hasattr(self, "weight_scale"):
+            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)
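These to_cpu/to_cuda hooks (now also moving the quantization scale) are what the offload path uses to stream block weights between host and GPU memory. Below is a minimal, hypothetical sketch of the double-buffering pattern they enable, not the repository's actual stream manager: one CUDA stream computes on the current block while a second stream prefetches the next block, and swap_weights pushes the finished block back to the CPU.

import torch

class TinyWeightStreamManager:
    # Simplified illustration only; the real manager in this repo carries more state.
    def __init__(self):
        self.active_weights = [None, None]  # [block being computed, prefetched block]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.load_stream = torch.cuda.Stream()

    def prefetch_weights(self, block_idx, blocks_weights):
        # Copy the next block host->device on a side stream so the transfer
        # overlaps with compute running on compute_stream.
        with torch.cuda.stream(self.load_stream):
            blocks_weights[block_idx].to_cuda(non_blocking=True)
            self.active_weights[1] = blocks_weights[block_idx]

    def swap_weights(self):
        # Wait for both streams, evict the finished block to CPU memory,
        # then promote the prefetched block to "active".
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        self.active_weights[0].to_cpu(non_blocking=True)
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], None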
@@ -109,7 +113,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
    def load_int8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].to(torch.float32)
+            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
@@ -245,7 +249,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
+        output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias.float(), input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)
@@ -268,7 +272,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
+        output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias.float(), input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)
......
import torch
+import numpy as np
from einops import rearrange
from lightx2v.attentions import attention
+from .utils import taylor_cache_init, derivative_approximation, taylor_formula
from ..utils_bf16 import apply_rotary_emb
-from typing import Dict
-import math
from ..transformer_infer import HunyuanTransformerInfer

(The taylor_cache_init, derivative_approximation and taylor_formula helpers previously defined inline in this file were moved to feature_caching/utils.py, reproduced further below.)

+class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def infer(
+        self,
+        weights,
+        img,
+        txt,
+        vec,
+        cu_seqlens_qkv,
+        max_seqlen_qkv,
+        freqs_cis,
+        token_replace_vec=None,
+        frist_frame_token_num=None,
+    ):
+        inp = img.clone()
+        vec_ = vec.clone()
+        weights.double_blocks_weights[0].to_cuda()
+        img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks_weights[0].img_mod.apply(vec_).chunk(6, dim=-1)
+        weights.double_blocks_weights[0].to_cpu_sync()
+        normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
+        modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
+        del normed_inp, inp, vec_
+
+        if self.scheduler.cnt == 0 or self.scheduler.cnt == self.scheduler.num_steps - 1:
+            should_calc = True
+            self.scheduler.accumulated_rel_l1_distance = 0
+        else:
+            rescale_func = np.poly1d(self.scheduler.coefficients)
+            self.scheduler.accumulated_rel_l1_distance += rescale_func(
+                ((modulated_inp - self.scheduler.previous_modulated_input).abs().mean() / self.scheduler.previous_modulated_input.abs().mean()).cpu().item()
+            )
+            if self.scheduler.accumulated_rel_l1_distance < self.scheduler.teacache_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.scheduler.accumulated_rel_l1_distance = 0
+        self.scheduler.previous_modulated_input = modulated_inp
+        del modulated_inp
+
+        if not should_calc:
+            img += self.scheduler.previous_residual
+        else:
+            ori_img = img.clone()
+            img, vec = super().infer(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+            self.scheduler.previous_residual = img - ori_img
+            del ori_img
+            torch.cuda.empty_cache()
+        return img, vec


-class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
+class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
+        assert not self.config["cpu_offload"], "Not support cpu-offload for TaylorCaching"

-    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
+    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]

        self.scheduler.current["stream"] = "double_stream"
        for i in range(self.double_blocks_num):
            self.scheduler.current["layer"] = i
-            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
+            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

        x = torch.cat((img, txt), 0)

        self.scheduler.current["stream"] = "single_stream"
        for i in range(self.single_blocks_num):
            self.scheduler.current["layer"] = i
-            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
+            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

        img = x[:img_seq_len, ...]
        return img, vec
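The TeaCaching class above decides, per denoising step, whether to run the full transformer or to reuse the residual cached from the last fully computed step. A stripped-down, self-contained illustration of that gating rule follows; the per-step values and threshold are made up and stand in for the scheduler's real state.

import numpy as np

# Made-up per-step "modulated input" magnitudes and threshold, for illustration only.
modulated_inputs = [1.00, 1.01, 1.03, 1.10, 1.12, 1.13, 1.35, 1.36]
num_steps = len(modulated_inputs)
teacache_thresh = 0.1

# Same rescaling polynomial the scheduler uses (highest-order coefficient first).
rescale = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02])

accumulated, previous = 0.0, None
for step, modulated in enumerate(modulated_inputs):
    if step == 0 or step == num_steps - 1:
        should_calc, accumulated = True, 0.0  # first and last steps always compute
    else:
        accumulated += rescale(abs(modulated - previous) / abs(previous))
        should_calc = accumulated >= teacache_thresh  # below threshold: reuse cached residual
        if should_calc:
            accumulated = 0.0
    previous = modulated
    print(step, "full transformer" if should_calc else "reuse residual")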
@@ -133,8 +142,24 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
            )
            img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]

-            img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
-            txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
+            img = self.infer_double_block_img_post_atten(
+                weights,
+                img,
+                img_attn,
+                img_mod1_gate,
+                img_mod2_shift,
+                img_mod2_scale,
+                img_mod2_gate,
+            )
+            txt = self.infer_double_block_txt_post_atten(
+                weights,
+                txt,
+                txt_attn,
+                txt_mod1_gate,
+                txt_mod2_shift,
+                txt_mod2_scale,
+                txt_mod2_gate,
+            )
            return img, txt

        elif self.scheduler.current["type"] == "taylor_cache":
@@ -166,7 +191,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
        return img, txt

-    def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
+    def infer_double_block_img_post_atten(
+        self,
+        weights,
+        img,
+        img_attn,
+        img_mod1_gate,
+        img_mod2_shift,
+        img_mod2_scale,
+        img_mod2_gate,
+    ):
        self.scheduler.current["module"] = "img_attn"
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
@@ -190,7 +224,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
        img = img + out
        return img

-    def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
+    def infer_double_block_txt_post_atten(
+        self,
+        weights,
+        txt,
+        txt_attn,
+        txt_mod1_gate,
+        txt_mod2_shift,
+        txt_mod2_scale,
+        txt_mod2_gate,
+    ):
        self.scheduler.current["module"] = "txt_attn"
        taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
......
from typing import Dict
import math
import torch


def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize the Taylor cache, expanding storage areas for Taylor series derivatives.
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current["step"] == 0:
        cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}


def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute finite-difference approximations of the feature derivatives and store them in the cache.
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    :param feature: Feature computed at the current (fully calculated) step
    """
    difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
    # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]

    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature

    for i in range(cache_dic["max_order"]):
        if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
        else:
            break

    cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors


def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Evaluate the cached Taylor expansion at the current step.
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current["step"] - current["activated_steps"][-1]
    # x = current['t'] - current['activated_times'][-1]
    output = 0
    for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
        output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
    return output
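A small, hypothetical walk-through of these helpers for a single cached module with max_order=1; the feature values are made up, and the import path is inferred from the relative "from .utils import ..." statement in transformer_infer.py above.

import torch
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.utils import (
    taylor_cache_init, derivative_approximation, taylor_formula,
)

cache_dic = {"cache": {-1: {"double_stream": {0: {"img_attn": {}}}}}, "max_order": 1, "first_enhance": 1}
current = {"stream": "double_stream", "layer": 0, "module": "img_attn", "step": 0, "activated_steps": [0, 0]}

taylor_cache_init(cache_dic, current)                             # step 0: create the empty slot
derivative_approximation(cache_dic, current, torch.tensor(2.0))   # store the 0th-order value

current["step"] = 5
current["activated_steps"].append(5)                              # next fully computed step
derivative_approximation(cache_dic, current, torch.tensor(4.0))   # 1st-order factor: (4 - 2) / 5 = 0.4

current["step"] = 6                                               # a cached step: extrapolate
print(taylor_formula(cache_dic, current))                         # tensor(4.4) = 4.0 + 0.4 * 1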
@@ -38,15 +38,7 @@ class HunyuanTransformerInfer:
                self.double_weights_stream_mgr.active_weights[0].to_cuda()

            with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
-                img, txt = self.infer_double_block(
-                    self.double_weights_stream_mgr.active_weights[0],
-                    img,
-                    txt,
-                    vec,
-                    cu_seqlens_qkv,
-                    max_seqlen_qkv,
-                    freqs_cis,
-                )
+                img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

            if double_block_idx < self.double_blocks_num - 1:
                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
@@ -54,23 +46,21 @@ class HunyuanTransformerInfer:
        x = torch.cat((img, txt), 0)
-        img = img.cpu()
-        txt = txt.cpu()
+        del img, txt
+        torch.cuda.empty_cache()

        for single_block_idx in range(self.single_blocks_num):
            if single_block_idx == 0:
                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
                self.single_weights_stream_mgr.active_weights[0].to_cuda()

            with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
-                x = self.infer_single_block(
-                    weights.single_blocks_weights[single_block_idx],
-                    x,
-                    vec,
-                    txt_seq_len,
-                    cu_seqlens_qkv,
-                    max_seqlen_qkv,
-                    freqs_cis,
-                )
+                x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

            if single_block_idx < self.single_blocks_num - 1:
                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
            self.single_weights_stream_mgr.swap_weights()
+        torch.cuda.empty_cache()

        img = x[:img_seq_len, ...]
        return img, vec
......
@@ -6,7 +6,7 @@ from lightx2v.text2v.models.networks.hunyuan.weights.transformer_weights import
from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
-from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
+from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
@@ -43,7 +43,9 @@ class HunyuanModel:
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
-            self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
+            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
+        elif self.config["feature_caching"] == "Tea":
+            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -107,3 +109,8 @@ class HunyuanModel:
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
+        if self.config["feature_caching"] == "Tea":
+            self.scheduler.cnt += 1
+            if self.scheduler.cnt == self.scheduler.num_steps:
+                self.scheduler.cnt = 0
@@ -3,7 +3,7 @@ from ..transformer_infer import WanTransformerInfer
import torch


-class WanTransformerInferFeatureCaching(WanTransformerInfer):
+class WanTransformerInferTeaCaching(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
......
@@ -12,7 +12,7 @@ from lightx2v.text2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
-from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
+from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
@@ -49,7 +49,7 @@ class WanModel:
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
-            self.transformer_infer_class = WanTransformerInferFeatureCaching
+            self.transformer_infer_class = WanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
......
-import torch
+from .utils import cache_init, cal_type
from ..scheduler import HunyuanScheduler
+import torch

(The cache_init, force_scheduler and cal_type helpers previously defined inline in this file were moved to feature_caching/utils.py, reproduced further below.)

+class HunyuanSchedulerTeaCaching(HunyuanScheduler):
+    def __init__(self, args, image_encoder_output):
+        super().__init__(args, image_encoder_output)
+        self.cnt = 0
+        self.num_steps = self.args.infer_steps
+        self.teacache_thresh = self.args.teacache_thresh
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.previous_residual = None
+        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
+
+    def clear(self):
+        if self.previous_residual is not None:
+            self.previous_residual = self.previous_residual.cpu()
+        if self.previous_modulated_input is not None:
+            self.previous_modulated_input = self.previous_modulated_input.cpu()
+        self.previous_modulated_input = None
+        self.previous_residual = None
+        torch.cuda.empty_cache()


-class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
-    def __init__(self, args):
-        super().__init__(args)
+class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
+    def __init__(self, args, image_encoder_output):
+        super().__init__(args, image_encoder_output)
        self.cache_dic, self.current = cache_init(self.infer_steps)

    def step_pre(self, step_index):
......
import torch


def cache_init(num_steps, model_kwargs=None):
    """
    Initialization for cache.
    """
    cache_dic = {}
    cache = {}
    cache_index = {}
    cache[-1] = {}
    cache_index[-1] = {}
    cache_index["layer_index"] = {}
    cache_dic["attn_map"] = {}
    cache_dic["attn_map"][-1] = {}
    cache_dic["attn_map"][-1]["double_stream"] = {}
    cache_dic["attn_map"][-1]["single_stream"] = {}
    cache_dic["k-norm"] = {}
    cache_dic["k-norm"][-1] = {}
    cache_dic["k-norm"][-1]["double_stream"] = {}
    cache_dic["k-norm"][-1]["single_stream"] = {}
    cache_dic["v-norm"] = {}
    cache_dic["v-norm"][-1] = {}
    cache_dic["v-norm"][-1]["double_stream"] = {}
    cache_dic["v-norm"][-1]["single_stream"] = {}
    cache_dic["cross_attn_map"] = {}
    cache_dic["cross_attn_map"][-1] = {}
    cache[-1]["double_stream"] = {}
    cache[-1]["single_stream"] = {}
    cache_dic["cache_counter"] = 0

    for j in range(20):
        cache[-1]["double_stream"][j] = {}
        cache_index[-1][j] = {}
        cache_dic["attn_map"][-1]["double_stream"][j] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
        cache_dic["k-norm"][-1]["double_stream"][j] = {}
        cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
        cache_dic["v-norm"][-1]["double_stream"][j] = {}
        cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}

    for j in range(40):
        cache[-1]["single_stream"][j] = {}
        cache_index[-1][j] = {}
        cache_dic["attn_map"][-1]["single_stream"][j] = {}
        cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
        cache_dic["k-norm"][-1]["single_stream"][j] = {}
        cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
        cache_dic["v-norm"][-1]["single_stream"][j] = {}
        cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}

    cache_dic["taylor_cache"] = False
    cache_dic["duca"] = False
    cache_dic["test_FLOPs"] = False

    mode = "Taylor"
    if mode == "original":
        cache_dic["cache_type"] = "random"
        cache_dic["cache_index"] = cache_index
        cache_dic["cache"] = cache
        cache_dic["fresh_ratio_schedule"] = "ToCa"
        cache_dic["fresh_ratio"] = 0.0
        cache_dic["fresh_threshold"] = 1
        cache_dic["force_fresh"] = "global"
        cache_dic["soft_fresh_weight"] = 0.0
        cache_dic["max_order"] = 0
        cache_dic["first_enhance"] = 1
    elif mode == "ToCa":
        cache_dic["cache_type"] = "random"
        cache_dic["cache_index"] = cache_index
        cache_dic["cache"] = cache
        cache_dic["fresh_ratio_schedule"] = "ToCa"
        cache_dic["fresh_ratio"] = 0.10
        cache_dic["fresh_threshold"] = 5
        cache_dic["force_fresh"] = "global"
        cache_dic["soft_fresh_weight"] = 0.0
        cache_dic["max_order"] = 0
        cache_dic["first_enhance"] = 1
        cache_dic["duca"] = False
    elif mode == "DuCa":
        cache_dic["cache_type"] = "random"
        cache_dic["cache_index"] = cache_index
        cache_dic["cache"] = cache
        cache_dic["fresh_ratio_schedule"] = "ToCa"
        cache_dic["fresh_ratio"] = 0.10
        cache_dic["fresh_threshold"] = 5
        cache_dic["force_fresh"] = "global"
        cache_dic["soft_fresh_weight"] = 0.0
        cache_dic["max_order"] = 0
        cache_dic["first_enhance"] = 1
        cache_dic["duca"] = True
    elif mode == "Taylor":
        cache_dic["cache_type"] = "random"
        cache_dic["cache_index"] = cache_index
        cache_dic["cache"] = cache
        cache_dic["fresh_ratio_schedule"] = "ToCa"
        cache_dic["fresh_ratio"] = 0.0
        cache_dic["fresh_threshold"] = 5
        cache_dic["max_order"] = 1
        cache_dic["force_fresh"] = "global"
        cache_dic["soft_fresh_weight"] = 0.0
        cache_dic["taylor_cache"] = True
        cache_dic["first_enhance"] = 1

    current = {}
    current["num_steps"] = num_steps
    current["activated_steps"] = [0]

    return cache_dic, current


def force_scheduler(cache_dic, current):
    if cache_dic["fresh_ratio"] == 0:
        # FORA
        linear_step_weight = 0.0
    else:
        # TokenCache
        linear_step_weight = 0.0
    step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
    threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
    # No forced refresh constraint for sensitive steps, since performance is already good enough;
    # you may have a try.
    cache_dic["cal_threshold"] = threshold
    # return threshold


def cal_type(cache_dic, current):
    """
    Determine the calculation type for this step.
    """
    if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
        # FORA: uniform
        first_step = current["step"] == 0
    else:
        # ToCa: enhanced first steps
        first_step = current["step"] < cache_dic["first_enhance"]
        # first_step = (current['step'] <= 3)

    force_fresh = cache_dic["force_fresh"]
    if not first_step:
        fresh_interval = cache_dic["cal_threshold"]
    else:
        fresh_interval = cache_dic["fresh_threshold"]

    if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
        current["type"] = "full"
        cache_dic["cache_counter"] = 0
        current["activated_steps"].append(current["step"])
        # current['activated_times'].append(current['t'])
        force_scheduler(cache_dic, current)
    elif cache_dic["taylor_cache"]:
        cache_dic["cache_counter"] += 1
        current["type"] = "taylor_cache"
    else:
        cache_dic["cache_counter"] += 1
        if cache_dic["duca"]:
            if cache_dic["cache_counter"] % 2 == 1:  # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
                current["type"] = "ToCa"
                # 'cache_noise' 'ToCa' 'FORA'
            else:
                current["type"] = "aggressive"
        else:
            current["type"] = "ToCa"
    # if current['step'] < 25:
    #     current['type'] = 'FORA'
    # else:
    #     current['type'] = 'aggressive'
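For the default "Taylor" mode above (fresh_ratio 0.0, fresh_threshold 5, taylor_cache True), the resulting schedule recomputes the network in full every fifth step and serves the Taylor extrapolation in between. A quick, hypothetical sanity check using only the two helpers above:

cache_dic, current = cache_init(num_steps=12)
for step in range(12):
    current["step"] = step
    cal_type(cache_dic, current)
    print(step, current["type"])
# Expected: "full" at steps 0, 5 and 10, "taylor_cache" everywhere else.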
@@ -10,3 +10,6 @@ class BaseScheduler:
    def step_pre(self, step_index):
        self.step_index = step_index
        self.latents = self.latents.to(dtype=torch.bfloat16)
+
+    def clear(self):
+        pass
@@ -2,7 +2,7 @@ import torch
from ..scheduler import WanScheduler


-class WanSchedulerFeatureCaching(WanScheduler):
+class WanSchedulerTeaCaching(WanScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.cnt = 0
......
@@ -341,6 +341,3 @@ class WanScheduler(BaseScheduler):
        self.lower_order_nums += 1

        self.latents = prev_sample
-
-    def clear(self):
-        pass