Commit 220a631f authored by Yang Yong (雍洋), committed by GitHub

update hunyuan cache (#79)


Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent 9da774a7
@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
 from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
 from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
 from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
-from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
+from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
+    HunyuanTransformerInferTaylorCaching,
+    HunyuanTransformerInferTeaCaching,
+    HunyuanTransformerInferAdaCaching,
+    HunyuanTransformerInferCustomCaching,
+)
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
@@ -156,10 +160,6 @@ class HunyuanModel:
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
-        if self.config["feature_caching"] == "Tea":
-            self.scheduler.cnt += 1
-            if self.scheduler.cnt == self.scheduler.num_steps:
-                self.scheduler.cnt = 0

     def _init_infer_class(self):
         self.pre_infer_class = HunyuanPreInfer
@@ -170,5 +170,9 @@ class HunyuanModel:
             self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = HunyuanTransformerInferTeaCaching
+        elif self.config["feature_caching"] == "Ada":
+            self.transformer_infer_class = HunyuanTransformerInferAdaCaching
+        elif self.config["feature_caching"] == "Custom":
+            self.transformer_infer_class = HunyuanTransformerInferCustomCaching
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
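
Note: everything in this commit keys off the single `feature_caching` config value, and the `elif` chain above grows with each new caching mode. A table-driven equivalent, as a sketch only; the class names come from the import hunk at the top, and the `"TaylorSeer"` key is presumed, since its `elif` sits above the hunk boundary:

```python
# Sketch: table-driven equivalent of the elif chain in _init_infer_class.
# "TaylorSeer" is an assumption here; its elif is cut off by the hunk above.
_INFER_CLASSES = {
    "TaylorSeer": HunyuanTransformerInferTaylorCaching,
    "Tea": HunyuanTransformerInferTeaCaching,
    "Ada": HunyuanTransformerInferAdaCaching,
    "Custom": HunyuanTransformerInferCustomCaching,
}

def resolve_infer_class(feature_caching: str):
    try:
        return _INFER_CLASSES[feature_caching]
    except KeyError:
        raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") from None
```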
@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         for cache in self.blocks_cache_even:
             for key in cache:
                 if cache[key] is not None:
-                    cache[key] = cache[key].cpu()
+                    if isinstance(cache[key], torch.Tensor):
+                        cache[key] = cache[key].cpu()
+                    elif isinstance(cache[key], dict):
+                        for k, v in cache[key].items():
+                            if isinstance(v, torch.Tensor):
+                                cache[key][k] = v.cpu()
             cache.clear()
         for cache in self.blocks_cache_odd:
             for key in cache:
                 if cache[key] is not None:
-                    cache[key] = cache[key].cpu()
+                    if isinstance(cache[key], torch.Tensor):
+                        cache[key] = cache[key].cpu()
+                    elif isinstance(cache[key], dict):
+                        for k, v in cache[key].items():
+                            if isinstance(v, torch.Tensor):
+                                cache[key][k] = v.cpu()
             cache.clear()
         torch.cuda.empty_cache()
...
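
The new `isinstance` branches handle cache entries that are dicts of tensors rather than flat tensors, so offload no longer assumes every entry has a `.cpu()` method. The same idea as a small recursive helper, a sketch rather than part of the commit (the committed code stays one level deep and mutates in place, which is all the current cache layout needs):

```python
import torch

def offload_to_cpu(value):
    # Move a cache entry to CPU whether it is a tensor or an arbitrarily
    # nested dict of tensors; non-tensor leaves pass through unchanged.
    if isinstance(value, torch.Tensor):
        return value.cpu()
    if isinstance(value, dict):
        return {k: offload_to_cpu(v) for k, v in value.items()}
    return value
```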
@@ -62,7 +62,7 @@ class WanModel:
             self.transformer_infer_class = WanTransformerInfer
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = WanTransformerInferTeaCaching
-        elif self.config["feature_caching"] == "Taylor":
+        elif self.config["feature_caching"] == "TaylorSeer":
             self.transformer_infer_class = WanTransformerInferTaylorCaching
         elif self.config["feature_caching"] == "Ada":
             self.transformer_infer_class = WanTransformerInferAdaCaching
...
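
This is a breaking rename for existing Wan configs: a config that still says `"Taylor"` no longer matches any branch and presumably falls through to the unsupported-type error at the end of the chain. In short:

```python
# Before this commit (Wan models and runner):
config["feature_caching"] = "Taylor"       # selected the TaylorSeer caching path
# After this commit:
config["feature_caching"] = "TaylorSeer"   # "Taylor" no longer matches any branch
```

The same rename is applied to WanRunner below, bringing Wan in line with the `"TaylorSeer"` key that HunyuanRunner already uses.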
@@ -6,7 +6,7 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
-from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
+from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching, HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching
 from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
             scheduler = HunyuanSchedulerTeaCaching(self.config)
         elif self.config.feature_caching == "TaylorSeer":
             scheduler = HunyuanSchedulerTaylorCaching(self.config)
+        elif self.config.feature_caching == "Ada":
+            scheduler = HunyuanSchedulerAdaCaching(self.config)
+        elif self.config.feature_caching == "Custom":
+            scheduler = HunyuanSchedulerCustomCaching(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
...
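
The runner's scheduler choice mirrors the model's infer-class choice, so one `feature_caching` value configures both ends of the pipeline. The implied pairing, as a sketch (all names from the hunks above):

```python
# Each feature_caching value selects a matched (scheduler, transformer infer) pair.
CACHING_PAIRS = {
    "Tea": (HunyuanSchedulerTeaCaching, HunyuanTransformerInferTeaCaching),
    "TaylorSeer": (HunyuanSchedulerTaylorCaching, HunyuanTransformerInferTaylorCaching),
    "Ada": (HunyuanSchedulerAdaCaching, HunyuanTransformerInferAdaCaching),
    "Custom": (HunyuanSchedulerCustomCaching, HunyuanTransformerInferCustomCaching),
}
```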
@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
             scheduler = WanScheduler(self.config)
         elif self.config.feature_caching == "Tea":
             scheduler = WanSchedulerTeaCaching(self.config)
-        elif self.config.feature_caching == "Taylor":
+        elif self.config.feature_caching == "TaylorSeer":
             scheduler = WanSchedulerTaylorCaching(self.config)
         elif self.config.feature_caching == "Ada":
             scheduler = WanSchedulerAdaCaching(self.config)
...
@@ -1,3 +1,2 @@
-from .utils import cache_init, cal_type
 from ..scheduler import HunyuanScheduler
 import torch
@@ -6,31 +5,32 @@ import torch

 class HunyuanSchedulerTeaCaching(HunyuanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
-        self.num_steps = self.config.infer_steps
-        self.teacache_thresh = self.config.teacache_thresh
-        self.accumulated_rel_l1_distance = 0
-        self.previous_modulated_input = None
-        self.previous_residual = None
-        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

     def clear(self):
-        if self.previous_residual is not None:
-            self.previous_residual = self.previous_residual.cpu()
-        if self.previous_modulated_input is not None:
-            self.previous_modulated_input = self.previous_modulated_input.cpu()
-        self.previous_modulated_input = None
-        self.previous_residual = None
-        torch.cuda.empty_cache()
+        self.transformer_infer.clear()


 class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cache_dic, self.current = cache_init(self.infer_steps)
+        pattern = [True, False, False, False]
+        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]

-    def step_pre(self, step_index):
-        super().step_pre(step_index)
-        self.current["step"] = step_index
-        cal_type(self.cache_dic, self.current)
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class HunyuanSchedulerAdaCaching(HunyuanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class HunyuanSchedulerCustomCaching(HunyuanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
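
After this refactor the schedulers become thin: the per-step bookkeeping (the old `cnt`, `accumulated_rel_l1_distance`, and coefficient fields, plus the `cache_init`/`cal_type` helpers) presumably moves into the transformer-infer classes, and `clear()` just delegates there. The one piece of state the TaylorSeer scheduler keeps is a fixed compute/reuse pattern; here is how it unrolls, where the True-computes/False-reuses reading is the usual TaylorSeer convention and an assumption on our part:

```python
# With infer_steps = 10, the schedule built in __init__ unrolls to:
#   [True, False, False, False, True, False, False, False, True, False]
# True  marks a step that runs the transformer fully and refreshes the cache;
# False marks a step whose features are extrapolated from the cached expansion.
infer_steps = 10
pattern = [True, False, False, False]
caching_records = (pattern * ((infer_steps + 3) // 4))[:infer_steps]
assert caching_records == [True, False, False, False] * 2 + [True, False]
```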
@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(

 class HunyuanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.infer_steps = self.config.infer_steps
         self.shift = 7.0
         self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
         assert len(self.timesteps) == self.infer_steps
...
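
Dropping `self.infer_steps` here only works if `BaseScheduler.__init__` now sets it, since the attribute is still read two lines later; that base-class change is outside this excerpt. For reference, a sketch of what `set_timesteps_sigmas` with `shift = 7.0` plausibly computes, namely the usual shifted flow-matching schedule; this is an assumption, not the project's code:

```python
import torch

def set_timesteps_sigmas(infer_steps, shift, device):
    # Hedged sketch: linear sigmas remapped by the flow-matching time shift,
    # as in HunyuanVideo-style schedulers. Not the project's implementation.
    sigmas = torch.linspace(1.0, 0.0, infer_steps + 1, device=device)
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    timesteps = sigmas[:-1] * 1000  # one timestep per inference step
    return timesteps, sigmas
```

Whatever the real implementation, it must return exactly `infer_steps` timesteps to satisfy the assertion above.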