Commit dcaefe63 authored by Yang Yong (雍洋), committed by GitHub

update feature caching (#78)


Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent bff9bd05
from abc import ABC, abstractmethod
import torch
import math


class BaseTransformerInfer(ABC):
    @abstractmethod
    def infer(self):
        pass

    def set_scheduler(self, scheduler):
        # Wire a bidirectional link: the schedulers' clear() methods below
        # rely on scheduler.transformer_infer pointing back at this object.
        self.scheduler = scheduler
        self.scheduler.transformer_infer = self


class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
    @abstractmethod
    def infer_calculating(self):
        pass

    @abstractmethod
    def infer_using_cache(self):
        pass

    @abstractmethod
    def get_taylor_step_diff(self):
        pass

    # 1. When fully calculated, store the output in the cache and refresh
    #    the first-order derivative via a finite difference over the step gap.
    def derivative_approximation(self, block_cache, module_name, out):
        if module_name not in block_cache:
            # First full calculation: only the 0th-order term is known.
            block_cache[module_name] = {0: out}
        else:
            step_diff = self.get_taylor_step_diff()
            previous_out = block_cache[module_name][0]
            block_cache[module_name][0] = out
            # f'(t) ~= (f(t) - f(t - step_diff)) / step_diff
            block_cache[module_name][1] = (out - previous_out) / step_diff

    # 2. Predict a skipped step by evaluating the Taylor expansion
    #    sum_i f^(i)/i! * x^i at x = the step gap.
    def taylor_formula(self, tensor_dict):
        x = self.get_taylor_step_diff()
        output = 0
        for i in range(len(tensor_dict)):
            output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
        return output
@@ -4,10 +4,11 @@ from lightx2v.common.offload.manager import (
    WeightAsyncStreamManager,
    LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *


class WanTransformerInfer(BaseTransformerInfer):  # was: class WanTransformerInfer:
    def __init__(self, config):
        self.config = config
        self.task = config["task"]
@@ -49,8 +50,10 @@ class WanTransformerInfer:
        else:
            self.infer_func = self._infer_without_offload

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        # True while the conditional CFG branch is being inferred.
        self.infer_conditional = True

    def switch_status(self):
        # Toggle between the conditional and unconditional branches.
        self.infer_conditional = not self.infer_conditional

    def _calculate_q_k_len(self, q, k_lens):
        q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
......
@@ -15,6 +15,9 @@ from lightx2v.models.networks.wan.infer.transformer_infer import (
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTeaCaching,
    WanTransformerInferTaylorCaching,
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -59,6 +62,12 @@ class WanModel:
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "Taylor":
            self.transformer_infer_class = WanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = WanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = WanTransformerInferCustomCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -201,10 +210,6 @@ class WanModel:
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        # Removed by this commit (the Tea step counter bookkeeping,
        # presumably now handled inside the caching layer):
        # if self.config["feature_caching"] == "Tea":
        #     self.scheduler.cnt += 1
        #     if self.scheduler.cnt >= self.scheduler.num_steps:
        #         self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
@@ -212,11 +217,6 @@ class WanModel:
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

            # Removed by this commit (same Tea bookkeeping as the
            # conditional branch above):
            # if self.config["feature_caching"] == "Tea":
            #     self.scheduler.cnt += 1
            #     if self.scheduler.cnt >= self.scheduler.num_steps:
            #         self.scheduler.cnt = 0

            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.config["cpu_offload"]:
......
@@ -9,6 +9,9 @@ from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTeaCaching,
    WanSchedulerTaylorCaching,
    WanSchedulerAdaCaching,
    WanSchedulerCustomCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -114,6 +117,12 @@ class WanRunner(DefaultRunner):
            scheduler = WanScheduler(self.config)
        elif self.config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(self.config)
        elif self.config.feature_caching == "Taylor":
            scheduler = WanSchedulerTaylorCaching(self.config)
        elif self.config.feature_caching == "Ada":
            scheduler = WanSchedulerAdaCaching(self.config)
        elif self.config.feature_caching == "Custom":
            scheduler = WanSchedulerCustomCaching(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)
......
@@ -7,7 +7,10 @@ class BaseScheduler:
        self.config = config
        self.step_index = 0
        self.latents = None
        self.infer_steps = config.infer_steps
        # One flag per step: True = run the full forward, False = use the cache.
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        # Back-reference set by BaseTransformerInfer.set_scheduler().
        self.transformer_infer = None

    def step_pre(self, step_index):
        self.step_index = step_index
......
import torch

from lightx2v.models.schedulers.wan.scheduler import WanScheduler  # was: from ..scheduler import WanScheduler

class WanSchedulerTeaCaching(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
        self.cnt = 0
        # Two transformer calls per denoising step (conditional + unconditional).
        self.num_steps = self.config.infer_steps * 2
        self.teacache_thresh = self.config.teacache_thresh
        # Separate accumulators and caches for the even (conditional) and
        # odd (unconditional) calls.
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        self.use_ret_steps = self.config.use_ret_steps
        if self.use_ret_steps:
            self.coefficients = self.config.coefficients[0]
            self.ret_steps = 5 * 2
            self.cutoff_steps = self.config.infer_steps * 2
        else:
            self.coefficients = self.config.coefficients[1]
            self.ret_steps = 1 * 2
            self.cutoff_steps = self.config.infer_steps * 2 - 2

    def clear(self):
        # Move cached tensors off the GPU before dropping the references.
        if self.previous_e0_even is not None:
            self.previous_e0_even = self.previous_e0_even.cpu()
        if self.previous_e0_odd is not None:
            self.previous_e0_odd = self.previous_e0_odd.cpu()
        if self.previous_residual_even is not None:
            self.previous_residual_even = self.previous_residual_even.cpu()
        if self.previous_residual_odd is not None:
            self.previous_residual_odd = self.previous_residual_odd.cpu()
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        torch.cuda.empty_cache()
        self.transformer_infer.clear()
class WanSchedulerTaylorCaching(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
        # Fully calculate every 4th step and Taylor-predict the three in
        # between, with one record list per CFG branch.
        pattern = [True, False, False, False]
        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
        self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]

    def clear(self):
        self.transformer_infer.clear()
class WanSchedulerAdaCaching(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        self.transformer_infer.clear()


class WanSchedulerCustomCaching(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        self.transformer_infer.clear()
@@ -18,6 +18,8 @@ class WanScheduler(BaseScheduler):
        self.solver_order = 2
        self.noise_pred = None
        # Second record list, used for the unconditional CFG branch.
        self.caching_records_2 = [True] * self.config.infer_steps

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
......