"tests/onebitadam/test_server_error.py" did not exist on "ab5534fc4c0f8ca21ada321f9730d723aa31288b"
Commit 220a631f authored by Yang Yong(雍洋)'s avatar Yang Yong(雍洋) Committed by GitHub
Browse files

update hunyuan cache (#79)


Co-authored-by: default avatarLinboyan-trc <1584340372@qq.com>
parent 9da774a7
......@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching,
)
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
......@@ -156,10 +160,6 @@ class HunyuanModel:
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt == self.scheduler.num_steps:
self.scheduler.cnt = 0
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
......@@ -170,5 +170,9 @@ class HunyuanModel:
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = HunyuanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = HunyuanTransformerInferCustomCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
......@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
for cache in self.blocks_cache_even:
for key in cache:
if cache[key] is not None:
cache[key] = cache[key].cpu()
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
for cache in self.blocks_cache_odd:
for key in cache:
if cache[key] is not None:
cache[key] = cache[key].cpu()
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
torch.cuda.empty_cache()
......
......@@ -62,7 +62,7 @@ class WanModel:
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Taylor":
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
......
......@@ -6,7 +6,7 @@ from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching, HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
......@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
scheduler = HunyuanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = HunyuanSchedulerAdaCaching(self.config)
elif self.config.feature_caching == "Custom":
scheduler = HunyuanSchedulerCustomCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
......
......@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
scheduler = WanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "Taylor":
elif self.config.feature_caching == "TaylorSeer":
scheduler = WanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = WanSchedulerAdaCaching(self.config)
......
from .utils import cache_init, cal_type
from ..scheduler import HunyuanScheduler
import torch
......@@ -6,31 +5,32 @@ import torch
class HunyuanSchedulerTeaCaching(HunyuanScheduler):
    """Hunyuan scheduler variant for the 'Tea' (TeaCache) feature-caching mode.

    Tracks the running step counter and the inputs/residuals needed to decide,
    via an accumulated relative-L1 distance against ``teacache_thresh``,
    whether a transformer step's cached result can be reused.
    """

    def __init__(self, config):
        super().__init__(config)
        self.cnt = 0  # current inference-step counter
        self.num_steps = self.config.infer_steps  # total denoising steps
        self.teacache_thresh = self.config.teacache_thresh  # reuse-decision threshold
        self.accumulated_rel_l1_distance = 0  # running L1 distance between modulated inputs
        self.previous_modulated_input = None  # last step's modulated input (torch.Tensor once set)
        self.previous_residual = None  # last step's transformer residual (torch.Tensor once set)
        # Polynomial coefficients used to rescale the raw L1 distance.
        # NOTE(review): presumably a model-specific fit from the TeaCache method — confirm provenance.
        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

    def clear(self):
        """Release cached tensors and propagate teardown to the transformer infer object."""
        # NOTE(review): the .cpu() copies below are immediately discarded by the
        # None assignments that follow, so they appear to be unnecessary work —
        # simply dropping the references would free the CUDA memory equally well.
        if self.previous_residual is not None:
            self.previous_residual = self.previous_residual.cpu()
        if self.previous_modulated_input is not None:
            self.previous_modulated_input = self.previous_modulated_input.cpu()
        self.previous_modulated_input = None
        self.previous_residual = None
        torch.cuda.empty_cache()  # return freed CUDA blocks to the driver
        self.transformer_infer.clear()
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
    """Hunyuan scheduler variant for the 'TaylorSeer' feature-caching mode.

    Precomputes a fixed 4-step caching pattern: one full transformer pass
    followed by three cached (Taylor-extrapolated) steps.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): this line and the two below come from an unmarked diff view;
        # cache_init-based state and the pattern-based records may be the old and
        # new implementations respectively — confirm against the merged file.
        self.cache_dic, self.current = cache_init(self.infer_steps)
        pattern = [True, False, False, False]  # True = full compute, False = reuse cache
        # Repeat the 4-step pattern to cover all steps, then trim to exactly infer_steps.
        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]

    def clear(self):
        """Propagate cache teardown to the transformer infer object."""
        self.transformer_infer.clear()
class HunyuanSchedulerAdaCaching(HunyuanScheduler):
    """Hunyuan scheduler variant for the 'Ada' (adaptive) feature-caching mode."""

    def __init__(self, config):
        super().__init__(config)

    def step_pre(self, step_index):
        """Per-step hook: record the current step and recompute the caching decision.

        NOTE(review): this method reads ``self.current`` and ``self.cache_dic``,
        which this class's __init__ never sets — in this unmarked diff view these
        lines most likely belong to the (old) TaylorCaching class above; confirm
        against the merged file before relying on this placement.
        """
        super().step_pre(step_index)
        self.current["step"] = step_index
        cal_type(self.cache_dic, self.current)  # decide full-compute vs cache-reuse for this step

    def clear(self):
        """Propagate cache teardown to the transformer infer object."""
        self.transformer_infer.clear()
class HunyuanSchedulerCustomCaching(HunyuanScheduler):
    """Hunyuan scheduler variant for the 'Custom' feature-caching mode.

    Holds no caching state of its own; all cache bookkeeping lives in the
    corresponding transformer infer class (HunyuanTransformerInferCustomCaching).
    """

    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        """Propagate cache teardown to the transformer infer object."""
        self.transformer_infer.clear()
......@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(
class HunyuanScheduler(BaseScheduler):
    """Base denoising scheduler for the Hunyuan model family.

    NOTE(review): only the constructor is visible in this hunk; the class
    continues beyond this view.
    """

    def __init__(self, config):
        super().__init__(config)
        self.infer_steps = self.config.infer_steps  # number of denoising steps
        self.shift = 7.0  # timestep-shift factor fed to set_timesteps_sigmas
        # Build the timestep/sigma schedule directly on the GPU.
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
        # NOTE(review): bare assert is stripped under `python -O`; an explicit
        # ValueError would be more robust for validating the schedule length.
        assert len(self.timesteps) == self.infer_steps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment