Commit dcaefe63 authored by Yang Yong (雍洋), committed by GitHub

update feature caching (#78)


Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent bff9bd05
from abc import ABC, abstractmethod
import torch
import math
class BaseTransformerInfer(ABC):
@abstractmethod
def infer(self):
pass
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.scheduler.transformer_infer = self
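# Base class for Taylor-series feature caching: subclasses implement the fully
# computed path (infer_calculating), the cache-reuse path (infer_using_cache),
# and the step gap used as the Taylor expansion variable (get_taylor_step_diff).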
class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
@abstractmethod
def infer_calculating(self):
pass
@abstractmethod
def infer_using_cache(self):
pass
@abstractmethod
def get_taylor_step_diff(self):
pass
    # 1. when a step is fully calculated, store the output and its finite-difference derivative in the cache
def derivative_approximation(self, block_cache, module_name, out):
if module_name not in block_cache:
block_cache[module_name] = {0: out}
else:
step_diff = self.get_taylor_step_diff()
previous_out = block_cache[module_name][0]
block_cache[module_name][0] = out
block_cache[module_name][1] = (out - previous_out) / step_diff
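    # Evaluate the cached truncated Taylor series at x = current step gap:
    # out = sum_i cache[i] * x**i / factorial(i), where cache[0] is the last fully
    # computed output and cache[1] its first-order finite-difference derivative.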
def taylor_formula(self, tensor_dict):
x = self.get_taylor_step_diff()
output = 0
for i in range(len(tensor_dict)):
output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
return output
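    # Example (first-order): with cache = {0: f_t, 1: (f_t - f_prev) / d}, evaluating
    # at the current step gap x yields f_t + x * (f_t - f_prev) / d, i.e. a linear
    # extrapolation of the cached feature to the skipped step.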
import gc
import numpy as np
import torch
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
from ..transformer_infer import WanTransformerInfer
# 1. TeaCaching: skip the full block computation when the rescaled change of the
# timestep embedding stays below teacache_thresh, reusing the cached residual instead.
class WanTransformerInferTeaCaching(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None
self.previous_residual_even = None
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = None
self.previous_residual_odd = None
self.use_ret_steps = config.use_ret_steps
self.set_attributes_by_task_and_model()
self.cnt = 0
    # TeaCache rescaling coefficients and step bounds (Wan2.1 only), selected by task, resolution, and model size
def set_attributes_by_task_and_model(self):
if self.config.task == "i2v":
if self.use_ret_steps:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
]
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
]
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
-18.82220707,
4.91518089,
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
elif self.config.task == "t2v":
if self.use_ret_steps:
if "1.3B" in self.config.model_path:
self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
if "14B" in self.config.model_path:
self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if "1.3B" in self.config.model_path:
self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
if "14B" in self.config.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
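    # TeaCache rule: accumulate a polynomial-rescaled relative L1 distance between
    # successive timestep embeddings and recompute the transformer blocks only once
    # the accumulated distance exceeds teacache_thresh; the conditional (even) and
    # unconditional (odd) passes are tracked separately.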
# calculate should_calc
def calculate_should_calc(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
# 1. timestep embedding
modulated_inp = embed0 if self.use_ret_steps else embed
        # 2. accumulate the rescaled relative L1 distance
should_calc = False
if self.infer_conditional:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
else:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
        # 3. return the decision
return should_calc
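    # The per-step decision is recorded in the scheduler's caching_records
    # (conditional pass) or caching_records_2 (unconditional pass) and then used to
    # choose between full computation and cache reuse.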
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.scheduler.caching_records[index] = should_calc
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
if self.config.enable_cfg:
self.switch_status()
self.cnt += 1
return x
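    # Full forward pass through all blocks; the input/output difference is stored as
    # the residual that infer_using_cache re-applies on skipped steps.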
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(self.blocks_num):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y_out, c_gate_msa)
if self.infer_conditional:
self.previous_residual_even = x - ori_x
else:
self.previous_residual_odd = x - ori_x
return x
def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
x += self.previous_residual_even
else:
x += self.previous_residual_odd
return x
def clear(self):
if self.previous_residual_even is not None:
self.previous_residual_even = self.previous_residual_even.cpu()
if self.previous_residual_odd is not None:
self.previous_residual_odd = self.previous_residual_odd.cpu()
if self.previous_e0_even is not None:
self.previous_e0_even = self.previous_e0_even.cpu()
if self.previous_e0_odd is not None:
self.previous_e0_odd = self.previous_e0_odd.cpu()
self.previous_residual_even = None
self.previous_residual_odd = None
self.previous_e0_even = None
self.previous_e0_odd = None
torch.cuda.empty_cache()
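# 2. TaylorCaching: each block keeps caches of its self-attention, cross-attention
# and FFN outputs (value plus first-order finite difference), held separately for
# the conditional and unconditional passes and used to reconstruct skipped steps.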
class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.blocks_cache_even = [{} for _ in range(self.blocks_num)]
self.blocks_cache_odd = [{} for _ in range(self.blocks_num)]
    # 1. get the Taylor step gap (number of steps since the last fully computed step);
    #    the conditional and unconditional passes keep separate caching records
def get_taylor_step_diff(self):
step_diff = 0
if self.infer_conditional:
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
else:
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records_2[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
return step_diff
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if caching_records_2[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
if self.config.enable_cfg:
self.switch_status()
return x
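    # On a fully computed step, refresh each block's Taylor cache for self_attn_out,
    # cross_attn_out and ffn_out via derivative_approximation.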
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
if self.infer_conditional:
self.derivative_approximation(self.blocks_cache_even[block_idx], "self_attn_out", y_out)
else:
self.derivative_approximation(self.blocks_cache_odd[block_idx], "self_attn_out", y_out)
attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
if self.infer_conditional:
self.derivative_approximation(self.blocks_cache_even[block_idx], "cross_attn_out", attn_out)
else:
self.derivative_approximation(self.blocks_cache_odd[block_idx], "cross_attn_out", attn_out)
y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
if self.infer_conditional:
self.derivative_approximation(self.blocks_cache_even[block_idx], "ffn_out", y_out)
else:
self.derivative_approximation(self.blocks_cache_odd[block_idx], "ffn_out", y_out)
x = self.infer_phase_5(x, y_out, c_gate_msa)
return x
def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
x = self.infer_block(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx)
return x
    # 1. reconstruct one block's outputs from the Taylor cache
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, i):
# 1. shift, scale, gate
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        # 2. add the Taylor-extrapolated self-attention, cross-attention, and FFN outputs as residuals
if self.infer_conditional:
out = self.taylor_formula(self.blocks_cache_even[i]["self_attn_out"])
out = out * gate_msa.squeeze(0)
x = x + out
out = self.taylor_formula(self.blocks_cache_even[i]["cross_attn_out"])
x = x + out
out = self.taylor_formula(self.blocks_cache_even[i]["ffn_out"])
out = out * c_gate_msa.squeeze(0)
x = x + out
else:
out = self.taylor_formula(self.blocks_cache_odd[i]["self_attn_out"])
out = out * gate_msa.squeeze(0)
x = x + out
out = self.taylor_formula(self.blocks_cache_odd[i]["cross_attn_out"])
x = x + out
out = self.taylor_formula(self.blocks_cache_odd[i]["ffn_out"])
out = out * c_gate_msa.squeeze(0)
x = x + out
return x
def clear(self):
for cache in self.blocks_cache_even:
for key in cache:
if cache[key] is not None:
cache[key] = cache[key].cpu()
cache.clear()
for cache in self.blocks_cache_odd:
for key in cache:
if cache[key] is not None:
cache[key] = cache[key].cpu()
cache.clear()
torch.cuda.empty_cache()
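# 3. AdaCaching: after every fully computed step, the change of a probe residual
# taken at a middle block is mapped through a codebook to decide how many of the
# following steps may reuse the cached residual.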
class WanTransformerInferAdaCaching(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
# 1. fixed args
self.decisive_double_block_id = self.blocks_num // 2
self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3}
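        # codebook: normalized change of the probe residual -> number of denoising
        # steps served by one full computation (smaller change -> longer skip)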
# 2. Create two instances of AdaArgs
self.args_even = AdaArgs(config)
self.args_odd = AdaArgs(config)
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
# 1. calculate the skipped step length
if index <= self.scheduler.infer_steps - 2:
self.args_even.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_even.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records[index + i] = False
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records_2
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
# 1. calculate the skipped step length
if index <= self.scheduler.infer_steps - 2:
self.args_odd.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_odd.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records_2[index + i] = False
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
if self.config.enable_cfg:
self.switch_status()
return x
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(self.blocks_num):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
if block_idx == self.decisive_double_block_id:
if self.infer_conditional:
self.args_even.now_residual_tiny = y_out * gate_msa.squeeze(0)
else:
self.args_odd.now_residual_tiny = y_out * gate_msa.squeeze(0)
attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y_out, c_gate_msa)
if self.infer_conditional:
self.args_even.previous_residual = x - ori_x
else:
self.args_odd.previous_residual = x - ori_x
return x
def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
x += self.args_even.previous_residual
else:
x += self.args_odd.previous_residual
return x
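    # Skip-length heuristic (even/conditional and odd/unconditional streams are
    # symmetric): the normalized L1 change of the probe residual per skipped step is
    # scaled by a motion-regularity term ("moreg", 1.0 outside the configured step
    # window) computed from strided differences along the token dimension, then
    # mapped to a skip length through the codebook.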
def calculate_skip_step_length(self):
if self.infer_conditional:
if self.args_even.previous_residual_tiny is None:
self.args_even.previous_residual_tiny = self.args_even.now_residual_tiny
return 1
else:
cache = self.args_even.previous_residual_tiny
res = self.args_even.now_residual_tiny
norm_ord = self.args_even.norm_ord
cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
cache_diff = cache_diff / self.args_even.skipped_step_length
if self.args_even.moreg_steps[0] <= self.scheduler.step_index <= self.args_even.moreg_steps[1]:
moreg = 0
for i in self.args_even.moreg_strides:
moreg_i = (res[i * self.args_even.spatial_dim :, :] - res[: -i * self.args_even.spatial_dim, :]).norm(p=norm_ord)
moreg_i /= res[i * self.args_even.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.args_even.spatial_dim, :].norm(p=norm_ord)
moreg += moreg_i
moreg = moreg / len(self.args_even.moreg_strides)
moreg = ((1 / self.args_even.moreg_hyp[0] * moreg) ** self.args_even.moreg_hyp[1]) / self.args_even.moreg_hyp[2]
else:
moreg = 1.0
mograd = self.args_even.mograd_mul * (moreg - self.args_even.previous_moreg) / self.args_even.skipped_step_length
self.args_even.previous_moreg = moreg
moreg = moreg + abs(mograd)
cache_diff = cache_diff * moreg
metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
if cache_diff < metric_thres[0]:
new_rate = cache_rates[0]
elif cache_diff < metric_thres[1]:
new_rate = cache_rates[1]
elif cache_diff < metric_thres[2]:
new_rate = cache_rates[2]
elif cache_diff < metric_thres[3]:
new_rate = cache_rates[3]
elif cache_diff < metric_thres[4]:
new_rate = cache_rates[4]
else:
new_rate = cache_rates[-1]
self.args_even.previous_residual_tiny = self.args_even.now_residual_tiny
return new_rate
else:
if self.args_odd.previous_residual_tiny is None:
self.args_odd.previous_residual_tiny = self.args_odd.now_residual_tiny
return 1
else:
cache = self.args_odd.previous_residual_tiny
res = self.args_odd.now_residual_tiny
norm_ord = self.args_odd.norm_ord
cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
cache_diff = cache_diff / self.args_odd.skipped_step_length
if self.args_odd.moreg_steps[0] <= self.scheduler.step_index <= self.args_odd.moreg_steps[1]:
moreg = 0
for i in self.args_odd.moreg_strides:
moreg_i = (res[i * self.args_odd.spatial_dim :, :] - res[: -i * self.args_odd.spatial_dim, :]).norm(p=norm_ord)
moreg_i /= res[i * self.args_odd.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.args_odd.spatial_dim, :].norm(p=norm_ord)
moreg += moreg_i
moreg = moreg / len(self.args_odd.moreg_strides)
moreg = ((1 / self.args_odd.moreg_hyp[0] * moreg) ** self.args_odd.moreg_hyp[1]) / self.args_odd.moreg_hyp[2]
else:
moreg = 1.0
mograd = self.args_odd.mograd_mul * (moreg - self.args_odd.previous_moreg) / self.args_odd.skipped_step_length
self.args_odd.previous_moreg = moreg
moreg = moreg + abs(mograd)
cache_diff = cache_diff * moreg
metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
if cache_diff < metric_thres[0]:
new_rate = cache_rates[0]
elif cache_diff < metric_thres[1]:
new_rate = cache_rates[1]
elif cache_diff < metric_thres[2]:
new_rate = cache_rates[2]
elif cache_diff < metric_thres[3]:
new_rate = cache_rates[3]
elif cache_diff < metric_thres[4]:
new_rate = cache_rates[4]
else:
new_rate = cache_rates[-1]
self.args_odd.previous_residual_tiny = self.args_odd.now_residual_tiny
return new_rate
def clear(self):
if self.args_even.previous_residual is not None:
self.args_even.previous_residual = self.args_even.previous_residual.cpu()
if self.args_even.previous_residual_tiny is not None:
self.args_even.previous_residual_tiny = self.args_even.previous_residual_tiny.cpu()
if self.args_even.now_residual_tiny is not None:
self.args_even.now_residual_tiny = self.args_even.now_residual_tiny.cpu()
if self.args_odd.previous_residual is not None:
self.args_odd.previous_residual = self.args_odd.previous_residual.cpu()
if self.args_odd.previous_residual_tiny is not None:
self.args_odd.previous_residual_tiny = self.args_odd.previous_residual_tiny.cpu()
if self.args_odd.now_residual_tiny is not None:
self.args_odd.now_residual_tiny = self.args_odd.now_residual_tiny.cpu()
self.args_even.previous_residual = None
self.args_even.previous_residual_tiny = None
self.args_even.now_residual_tiny = None
self.args_odd.previous_residual = None
self.args_odd.previous_residual_tiny = None
self.args_odd.now_residual_tiny = None
torch.cuda.empty_cache()
class AdaArgs:
def __init__(self, config):
# Cache related attributes
self.previous_residual_tiny = None
self.now_residual_tiny = None
self.norm_ord = 1
self.skipped_step_length = 1
self.previous_residual = None
# Moreg related attributes
self.previous_moreg = 1.0
self.moreg_strides = [1]
self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
self.moreg_hyp = [0.385, 8, 1, 2]
self.mograd_mul = 10
self.spatial_dim = 1536
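# 4. CustomCaching: TeaCache-style decisions on when to recompute, combined with a
# Taylor expansion of the accumulated block residual instead of reusing the raw
# residual as TeaCaching does.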
class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None
self.previous_residual_even = None
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = None
self.previous_residual_odd = None
self.cache_even = {}
self.cache_odd = {}
self.use_ret_steps = config.use_ret_steps
self.set_attributes_by_task_and_model()
self.cnt = 0
    # TeaCache rescaling coefficients and step bounds (Wan2.1 only), selected by task, resolution, and model size
def set_attributes_by_task_and_model(self):
if self.config.task == "i2v":
if self.use_ret_steps:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
]
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
]
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
-18.82220707,
4.91518089,
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
elif self.config.task == "t2v":
if self.use_ret_steps:
if "1.3B" in self.config.model_path:
self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
if "14B" in self.config.model_path:
self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if "1.3B" in self.config.model_path:
self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
if "14B" in self.config.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
    # 1. get the Taylor step gap (number of steps since the last fully computed step);
    #    the conditional and unconditional passes keep separate caching records
def get_taylor_step_diff(self):
step_diff = 0
if self.infer_conditional:
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
else:
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records_2[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
return step_diff
# calculate should_calc
def calculate_should_calc(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
# 1. timestep embedding
modulated_inp = embed0 if self.use_ret_steps else embed
        # 2. accumulate the rescaled relative L1 distance
should_calc = False
if self.infer_conditional:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
                    should_calc = True
                    self.accumulated_rel_l1_distance_even = 0
            self.previous_e0_even = modulated_inp.clone()
else:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
                    should_calc = True
                    self.accumulated_rel_l1_distance_odd = 0
            self.previous_e0_odd = modulated_inp.clone()
        # 3. return the decision
return should_calc
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.scheduler.caching_records[index] = should_calc
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
if self.config.enable_cfg:
self.switch_status()
self.cnt += 1
return x
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(self.blocks_num):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.infer_phase_5(x, y_out, c_gate_msa)
if self.infer_conditional:
self.previous_residual_even = x - ori_x
self.derivative_approximation(self.cache_even, "previous_residual", self.previous_residual_even)
else:
self.previous_residual_odd = x - ori_x
self.derivative_approximation(self.cache_odd, "previous_residual", self.previous_residual_odd)
return x
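    # Skipped steps add a Taylor-extrapolated residual from the cache rather than
    # the raw previous residual used by TeaCaching.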
def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if self.infer_conditional:
x += self.taylor_formula(self.cache_even["previous_residual"])
else:
x += self.taylor_formula(self.cache_odd["previous_residual"])
return x
def clear(self):
if self.previous_residual_even is not None:
self.previous_residual_even = self.previous_residual_even.cpu()
if self.previous_residual_odd is not None:
self.previous_residual_odd = self.previous_residual_odd.cpu()
if self.previous_e0_even is not None:
self.previous_e0_even = self.previous_e0_even.cpu()
if self.previous_e0_odd is not None:
self.previous_e0_odd = self.previous_e0_odd.cpu()
for key in self.cache_even:
if self.cache_even[key] is not None and hasattr(self.cache_even[key], "cpu"):
self.cache_even[key] = self.cache_even[key].cpu()
self.cache_even.clear()
for key in self.cache_odd:
if self.cache_odd[key] is not None and hasattr(self.cache_odd[key], "cpu"):
self.cache_odd[key] = self.cache_odd[key].cpu()
self.cache_odd.clear()
self.previous_residual_even = None
self.previous_residual_odd = None
self.previous_e0_even = None
self.previous_e0_odd = None
torch.cuda.empty_cache()
@@ -4,10 +4,11 @@ from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
self.task = config["task"]
@@ -49,8 +50,10 @@ class WanTransformerInfer:
else:
self.infer_func = self._infer_without_offload
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.infer_conditional = True
def switch_status(self):
self.infer_conditional = not self.infer_conditional
def _calculate_q_k_len(self, q, k_lens):
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
@@ -15,6 +15,9 @@ from lightx2v.models.networks.wan.infer.transformer_infer import (
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferTeaCaching,
WanTransformerInferTaylorCaching,
WanTransformerInferAdaCaching,
WanTransformerInferCustomCaching,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -59,6 +62,12 @@ class WanModel:
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Taylor":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -201,10 +210,6 @@ class WanModel:
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
@@ -212,11 +217,6 @@ class WanModel:
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
@@ -9,6 +9,9 @@ from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
WanSchedulerTeaCaching,
WanSchedulerTaylorCaching,
WanSchedulerAdaCaching,
WanSchedulerCustomCaching,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -114,6 +117,12 @@ class WanRunner(DefaultRunner):
scheduler = WanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "Taylor":
scheduler = WanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = WanSchedulerAdaCaching(self.config)
elif self.config.feature_caching == "Custom":
scheduler = WanSchedulerCustomCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
@@ -7,7 +7,10 @@ class BaseScheduler:
self.config = config
self.step_index = 0
self.latents = None
self.infer_steps = config.infer_steps
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
def step_pre(self, step_index):
self.step_index = step_index
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanSchedulerTeaCaching(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.cnt = 0
self.num_steps = self.config.infer_steps * 2
self.teacache_thresh = self.config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
self.use_ret_steps = self.config.use_ret_steps
if self.use_ret_steps:
self.coefficients = self.config.coefficients[0]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
self.coefficients = self.config.coefficients[1]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
def clear(self):
if self.previous_e0_even is not None:
self.previous_e0_even = self.previous_e0_even.cpu()
if self.previous_e0_odd is not None:
self.previous_e0_odd = self.previous_e0_odd.cpu()
if self.previous_residual_even is not None:
self.previous_residual_even = self.previous_residual_even.cpu()
if self.previous_residual_odd is not None:
self.previous_residual_odd = self.previous_residual_odd.cpu()
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
torch.cuda.empty_cache()
self.transformer_infer.clear()
class WanSchedulerTaylorCaching(WanScheduler):
def __init__(self, config):
super().__init__(config)
pattern = [True, False, False, False]
self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
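        # Blocks are fully recomputed on every 4th step; the three steps in between
        # are reconstructed from the Taylor cache (conditional and unconditional
        # passes use separate caching records).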
def clear(self):
self.transformer_infer.clear()
class WanSchedulerAdaCaching(WanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class WanSchedulerCustomCaching(WanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
@@ -18,6 +18,8 @@ class WanScheduler(BaseScheduler):
self.solver_order = 2
self.noise_pred = None
self.caching_records_2 = [True] * self.config.infer_steps
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)