Commit dcaefe63 authored by Yang Yong(雍洋), committed by GitHub

update feature caching (#78)


Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent bff9bd05
from abc import ABC, abstractmethod
import torch
import math


class BaseTransformerInfer(ABC):
    @abstractmethod
    def infer(self):
        pass

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.scheduler.transformer_infer = self


class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
    @abstractmethod
    def infer_calculating(self):
        pass

    @abstractmethod
    def infer_using_cache(self):
        pass

    @abstractmethod
    def get_taylor_step_diff(self):
        pass

    # 1. When a step is fully calculated, store its output (order 0) and the
    #    finite-difference derivative (order 1) in the cache.
    def derivative_approximation(self, block_cache, module_name, out):
        if module_name not in block_cache:
            block_cache[module_name] = {0: out}
        else:
            step_diff = self.get_taylor_step_diff()
            previous_out = block_cache[module_name][0]
            block_cache[module_name][0] = out
            block_cache[module_name][1] = (out - previous_out) / step_diff

    # 2. When a step reuses the cache, predict the output with a Taylor expansion.
    def taylor_formula(self, tensor_dict):
        x = self.get_taylor_step_diff()
        output = 0
        for i in range(len(tensor_dict)):
            output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
        return output
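Taken together, derivative_approximation and taylor_formula form a first-order Taylor predictor: fully computed steps refresh the cached output and its finite-difference derivative, and cached steps extrapolate from them. A minimal sketch of that flow, assuming the class above is importable; the _SketchInfer subclass and the constant step difference of 1.0 are illustrative assumptions, not part of this commit:

import torch

class _SketchInfer(BaseTaylorCachingTransformerInfer):
    # Hypothetical subclass: only the pieces the sketch needs.
    def infer(self):
        pass

    def infer_calculating(self):
        pass

    def infer_using_cache(self):
        pass

    def get_taylor_step_diff(self):
        # Assumed: a constant gap of one step between cached evaluations.
        return 1.0

infer = _SketchInfer()
cache = {}

# Two fully computed steps fill the cache: {0: latest output, 1: derivative}.
infer.derivative_approximation(cache, "self_attn", torch.tensor([1.0]))
infer.derivative_approximation(cache, "self_attn", torch.tensor([3.0]))

# A cached step extrapolates: 3.0 + (3.0 - 1.0) / 1.0 * 1.0 = 5.0
print(infer.taylor_formula(cache["self_attn"]))  # tensor([5.])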
@@ -4,10 +4,11 @@ from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 )
+from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *


-class WanTransformerInfer:
+class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
         self.task = config["task"]
@@ -49,8 +50,10 @@ class WanTransformerInfer:
         else:
             self.infer_func = self._infer_without_offload

-    def set_scheduler(self, scheduler):
-        self.scheduler = scheduler
+        self.infer_conditional = True
+
+    def switch_status(self):
+        self.infer_conditional = not self.infer_conditional

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
......
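WanTransformerInfer now inherits set_scheduler from BaseTransformerInfer and instead gains an infer_conditional flag plus switch_status() to flip it; together with the two caching-record lists added on the scheduler below, this suggests the flag tells the conditional pass apart from the unconditional CFG pass. A hypothetical toggle pattern, runnable with a stub; the real call sites are outside the hunks shown here:

# Hypothetical driver code, not part of this commit; _StubInfer stands in for
# WanTransformerInfer.infer(weights, grid_sizes, embed, *pre_infer_out).
class _StubInfer:
    def __init__(self):
        self.infer_conditional = True

    def switch_status(self):
        self.infer_conditional = not self.infer_conditional

    def infer(self):
        return "cond pass" if self.infer_conditional else "uncond pass"

transformer_infer = _StubInfer()
enable_cfg = True

noise_pred_cond = transformer_infer.infer()
if enable_cfg:
    transformer_infer.switch_status()   # infer_conditional -> False for the unconditional pass
    noise_pred_uncond = transformer_infer.infer()
    transformer_infer.switch_status()   # back to True before the next denoising step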
@@ -15,6 +15,9 @@ from lightx2v.models.networks.wan.infer.transformer_infer import (
 )
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
     WanTransformerInferTeaCaching,
+    WanTransformerInferTaylorCaching,
+    WanTransformerInferAdaCaching,
+    WanTransformerInferCustomCaching,
 )
 from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -59,6 +62,12 @@ class WanModel:
             self.transformer_infer_class = WanTransformerInfer
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = WanTransformerInferTeaCaching
+        elif self.config["feature_caching"] == "Taylor":
+            self.transformer_infer_class = WanTransformerInferTaylorCaching
+        elif self.config["feature_caching"] == "Ada":
+            self.transformer_infer_class = WanTransformerInferAdaCaching
+        elif self.config["feature_caching"] == "Custom":
+            self.transformer_infer_class = WanTransformerInferCustomCaching
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -201,10 +210,6 @@ class WanModel:
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

-        if self.config["feature_caching"] == "Tea":
-            self.scheduler.cnt += 1
-            if self.scheduler.cnt >= self.scheduler.num_steps:
-                self.scheduler.cnt = 0
         self.scheduler.noise_pred = noise_pred_cond

         if self.config["enable_cfg"]:
@@ -212,11 +217,6 @@ class WanModel:
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

-            if self.config["feature_caching"] == "Tea":
-                self.scheduler.cnt += 1
-                if self.scheduler.cnt >= self.scheduler.num_steps:
-                    self.scheduler.cnt = 0
             self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

         if self.config["cpu_offload"]:
......
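With this change the per-pass Tea counter bookkeeping (scheduler.cnt) is dropped from WanModel.infer, and the caching mode is selected purely from config["feature_caching"]. A hypothetical config fragment for enabling one of the new modes; only feature_caching is introduced by this diff, the other keys are illustrative and taken from fields the code above already reads:

config = {
    "feature_caching": "Taylor",  # or "Tea", "Ada", "Custom"
    "task": "t2v",                # illustrative; WanTransformerInfer.__init__ reads config["task"]
    "enable_cfg": True,           # illustrative; toggles the second (unconditional) pass above
}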
@@ -9,6 +9,9 @@ from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerTeaCaching,
+    WanSchedulerTaylorCaching,
+    WanSchedulerAdaCaching,
+    WanSchedulerCustomCaching,
 )
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -114,6 +117,12 @@ class WanRunner(DefaultRunner):
             scheduler = WanScheduler(self.config)
         elif self.config.feature_caching == "Tea":
             scheduler = WanSchedulerTeaCaching(self.config)
+        elif self.config.feature_caching == "Taylor":
+            scheduler = WanSchedulerTaylorCaching(self.config)
+        elif self.config.feature_caching == "Ada":
+            scheduler = WanSchedulerAdaCaching(self.config)
+        elif self.config.feature_caching == "Custom":
+            scheduler = WanSchedulerCustomCaching(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
......
@@ -7,7 +7,10 @@ class BaseScheduler:
         self.config = config
         self.step_index = 0
         self.latents = None
+        self.infer_steps = config.infer_steps
+        self.caching_records = [True] * config.infer_steps
         self.flag_df = False
+        self.transformer_infer = None

     def step_pre(self, step_index):
         self.step_index = step_index
......
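BaseScheduler now exposes infer_steps, a per-step caching_records list (all True by default, i.e. every step fully computed), and a transformer_infer back-reference that BaseTransformerInfer.set_scheduler fills in. A self-contained sketch of how such a record list gates compute-versus-reuse per step; the stub class and loop are illustrative, not lightx2v code:

class _StubScheduler:
    # Stand-in carrying the same two fields BaseScheduler now has.
    def __init__(self, infer_steps):
        self.infer_steps = infer_steps
        self.caching_records = [True] * infer_steps

scheduler = _StubScheduler(infer_steps=4)
scheduler.caching_records = [True, False, False, False]  # e.g. the Taylor pattern defined below

for step_index, compute_fully in enumerate(scheduler.caching_records):
    if compute_fully:
        print(f"step {step_index}: full forward pass, refresh the feature cache")
    else:
        print(f"step {step_index}: skip recomputation, reuse/extrapolate cached features")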
-import torch
-from ..scheduler import WanScheduler
+from lightx2v.models.schedulers.wan.scheduler import WanScheduler


 class WanSchedulerTeaCaching(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
-        self.num_steps = self.config.infer_steps * 2
-        self.teacache_thresh = self.config.teacache_thresh
-        self.accumulated_rel_l1_distance_even = 0
-        self.accumulated_rel_l1_distance_odd = 0
-        self.previous_e0_even = None
-        self.previous_e0_odd = None
-        self.previous_residual_even = None
-        self.previous_residual_odd = None
-        self.use_ret_steps = self.config.use_ret_steps
-        if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
-            self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
-        else:
-            self.coefficients = self.config.coefficients[1]
-            self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2

     def clear(self):
-        if self.previous_e0_even is not None:
-            self.previous_e0_even = self.previous_e0_even.cpu()
-        if self.previous_e0_odd is not None:
-            self.previous_e0_odd = self.previous_e0_odd.cpu()
-        if self.previous_residual_even is not None:
-            self.previous_residual_even = self.previous_residual_even.cpu()
-        if self.previous_residual_odd is not None:
-            self.previous_residual_odd = self.previous_residual_odd.cpu()
-        self.previous_e0_even = None
-        self.previous_e0_odd = None
-        self.previous_residual_even = None
-        self.previous_residual_odd = None
-        torch.cuda.empty_cache()
+        self.transformer_infer.clear()
+
+
+class WanSchedulerTaylorCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+        pattern = [True, False, False, False]
+        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
+        self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
+
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class WanSchedulerAdaCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class WanSchedulerCustomCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
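WanSchedulerTaylorCaching fixes its schedule up front: every fourth step is fully computed and the following three reuse the cache, with caching_records_2 holding the same pattern (presumably for the second CFG pass). A worked example of the construction above with config.infer_steps = 10:

pattern = [True, False, False, False]
infer_steps = 10  # example value
records = (pattern * ((infer_steps + 3) // 4))[:infer_steps]
print(records)
# [True, False, False, False, True, False, False, False, True, False]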
@@ -18,6 +18,8 @@ class WanScheduler(BaseScheduler):
         self.solver_order = 2
         self.noise_pred = None

+        self.caching_records_2 = [True] * self.config.infer_steps
+
     def prepare(self, image_encoder_output=None):
         self.generator = torch.Generator(device=self.device)
         self.generator.manual_seed(self.config.seed)
......