Commit 220a631f authored by Yang Yong(雍洋), committed by GitHub

update hunyuan cache (#79)


Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent 9da774a7
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
import torch
import numpy as np
from einops import rearrange
from .utils import taylor_cache_init, derivative_approximation, taylor_formula
from ..utils_bf16 import apply_rotary_emb
from ..transformer_infer import HunyuanTransformerInfer
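# NOTE (editor's sketch, not part of this commit): `taylor_cache_init`,
# `derivative_approximation` and `taylor_formula` are the Taylor-caching helpers.
# Based on how they are called below, they are assumed to (1) initialize a cache slot,
# (2) store a feature together with a finite-difference estimate of its derivative
# across denoising steps, and (3) extrapolate a cached feature forward in time:
#     y(t + k) ≈ y(t) + y'(t) * k  (first-order Taylor expansion)
# A minimal first-order sketch under those assumptions (hypothetical names, not the
# actual lightx2v implementation):
#
#     def derivative_approximation_sketch(cache, key, value, step):
#         prev = cache.get(key)
#         if prev is not None and step > prev["step"]:
#             deriv = (value - prev["value"]) / (step - prev["step"])
#         else:
#             deriv = torch.zeros_like(value)
#         cache[key] = {"value": value, "derivative": deriv, "step": step}
#
#     def taylor_formula_sketch(entry, step_diff):
#         return entry["value"] + entry["derivative"] * step_diff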
class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = self.config.teacache_thresh
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
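# TeaCache: skip the full transformer forward on steps where the timestep-modulated
# input barely changes. `coefficients` define a degree-4 polynomial (used via
# np.poly1d below) that rescales the relative L1 change of the modulated input into
# an estimate of the output change; the estimate is accumulated across skipped steps
# and compared against `teacache_thresh`.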
# 1. only in tea-cache: judge whether the next step needs full computation
def calculate_should_calc(self, img, vec, weights):
# 1. timestep embedding
inp = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks[0].img_mod.apply(vec_).chunk(6, dim=-1)
normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
del normed_inp, inp, vec_
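# relative L1 change of the modulated input between consecutive steps; the fitted
# polynomial maps this input-side metric to an estimate of the output-side change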
# 2. L1 calculate
if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
del modulated_inp
# 3. return the judgement
return should_calc
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
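# caching_records[index] is True when this step must be fully computed; otherwise the
# cached residual from the last full step is reused, and the decision for the next
# step is refreshed below via calculate_should_calc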
if caching_records[index]:
img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
else:
img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if index <= self.scheduler.infer_steps - 2:
should_calc = self.calculate_should_calc(img, vec, weights)
self.scheduler.caching_records[index + 1] = should_calc
return img, vec
def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
# 1. copy the noise
ori_img = img.clone()
# 2. fully calculate
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights.double_blocks[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num
)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
# 3. cache the residual
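# (output minus input; on skipped steps infer_using_cache simply adds it back)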
self.previous_residual = img - ori_img
return img, vec
def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
img += self.previous_residual
return img, vec
def clear(self):
if self.previous_residual is not None:
self.previous_residual = self.previous_residual.cpu()
if self.previous_modulated_input is not None:
self.previous_modulated_input = self.previous_modulated_input.cpu()
self.previous_modulated_input = None
self.previous_residual = None
torch.cuda.empty_cache()
class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer, BaseTaylorCachingTransformerInfer):
def __init__(self, config):
super().__init__(config)
assert not self.config["cpu_offload"], "cpu_offload is not supported for TaylorCaching"
self.double_blocks_cache = [{} for _ in range(self.double_blocks_num)]
self.single_blocks_cache = [{} for _ in range(self.single_blocks_num)]
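# TaylorSeer-style caching: each double/single block keeps a per-module cache of
# feature values and their estimated derivatives. Fully computed steps refresh the
# caches via derivative_approximation; skipped steps reconstruct module outputs with
# taylor_formula instead of running attention/MLP.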
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if caching_records[index]:
return self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
else:
return self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
# 1. get the Taylor step_diff when the scheduler keeps only one caching_records list
def get_taylor_step_diff(self):
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
return step_diff
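# step_diff is the distance (in denoising steps) from the last fully calculated step
# and serves as the expansion offset in the Taylor extrapolation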
def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
self.scheduler.current["stream"] = "double_stream"
for i in range(self.double_blocks_num):
self.scheduler.current["layer"] = i
img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
self.derivative_approximation(self.double_blocks_cache[i], "img_attn", img_out)
self.derivative_approximation(self.double_blocks_cache[i], "txt_attn", txt_out)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights.double_blocks[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
self.derivative_approximation(self.double_blocks_cache[i], "img_mlp", img_out)
self.derivative_approximation(self.double_blocks_cache[i], "txt_mlp", txt_out)
img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num
)
self.derivative_approximation(self.single_blocks_cache[i], "total", out)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
return img, vec
self.scheduler.current["stream"] = "single_stream"
def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
self.scheduler.current["layer"] = i
x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i)
img = x[:img_seq_len, ...]
return img, vec
# 1. taylor extrapolation using the cached derivatives
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod_out.chunk(6, dim=-1)
txt_mod_out = weights.txt_mod.apply(vec_silu)
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod_out.chunk(6, dim=-1)
out = self.taylor_formula(self.double_blocks_cache[i]["img_attn"])
out = out * img_mod1_gate
img = img + out
out = self.taylor_formula(self.double_blocks_cache[i]["img_mlp"])
out = out * img_mod2_gate
img = img + out
out = self.taylor_formula(self.double_blocks_cache[i]["txt_attn"])
out = out * txt_mod1_gate
txt = txt + out
out = self.taylor_formula(self.double_blocks_cache[i]["txt_mlp"])
out = out * txt_mod2_gate
txt = txt + out
return img, txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
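# only the cheap modulation path is recomputed on cached steps; the expensive
# attention + MLP output is extrapolated from the per-block cache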
out = self.taylor_formula(self.single_blocks_cache[i]["total"])
out = out * mod_gate
x = x + out
return x
def clear(self):
for cache in self.double_blocks_cache:
for key in cache:
if cache[key] is not None:
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
for cache in self.single_blocks_cache:
for key in cache:
if cache[key] is not None:
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
torch.cuda.empty_cache()
class HunyuanTransformerInferAdaCaching(HunyuanTransformerInfer):
def __init__(self, config):
super().__init__(config)
# 1. fixed args
self.decisive_double_block_id = 10
self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3}
# 2. cache
self.previous_residual_tiny = None
self.now_residual_tiny = None
self.norm_ord = 1
self.skipped_step_length = 1
self.previous_residual = None
# 3. moreg
self.previous_moreg = 1.0
self.moreg_strides = [1]
self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
self.moreg_hyp = [0.385, 8, 1, 2]
self.mograd_mul = 10
self.spatial_dim = 3072
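# AdaCaching: `codebook` maps thresholds on the normalized residual change rate to
# cache rates, i.e. how many consecutive steps may reuse the cached residual; a more
# stable residual (smaller cache_diff) allows a longer skip. `decisive_double_block_id`
# selects the probe block whose gated MLP output is recorded as a cheap change signal,
# and the moreg_* fields parameterize what appears to be a motion-regularity
# correction applied to that signal.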
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if caching_records[index]:
img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
# 3. calculate the skipped step length
if index <= self.scheduler.infer_steps - 2:
self.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records[index + i] = False
else:
img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
return img, vec
def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
ori_img = img.clone()
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights.double_blocks[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
if i == self.decisive_double_block_id:
self.now_residual_tiny = img_out * img_mod2_gate
img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num
)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
self.previous_residual = img - ori_img
return img, vec
self.scheduler.current["module"] = "img_mlp"
def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
img += self.previous_residual
return img, vec
# 1. AdaCaching's algorithm to calculate the skip step length
def calculate_skip_step_length(self):
if self.previous_residual_tiny is None:
self.previous_residual_tiny = self.now_residual_tiny
return 1
else:
cache = self.previous_residual_tiny
res = self.now_residual_tiny
norm_ord = self.norm_ord
cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
cache_diff = cache_diff / self.skipped_step_length
if self.moreg_steps[0] <= self.scheduler.step_index <= self.moreg_steps[1]:
moreg = 0
for i in self.moreg_strides:
moreg_i = (res[i * self.spatial_dim :, :] - res[: -i * self.spatial_dim, :]).norm(p=norm_ord)
moreg_i /= res[i * self.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.spatial_dim, :].norm(p=norm_ord)
moreg += moreg_i
moreg = moreg / len(self.moreg_strides)
moreg = ((1 / self.moreg_hyp[0] * moreg) ** self.moreg_hyp[1]) / self.moreg_hyp[2]
else:
moreg = 1.0
mograd = self.mograd_mul * (moreg - self.previous_moreg) / self.skipped_step_length
self.previous_moreg = moreg
moreg = moreg + abs(mograd)
cache_diff = cache_diff * moreg
metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
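# piecewise-constant lookup: take the first threshold bucket containing cache_diff;
# the final 1.00 threshold acts as a catch-all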
if cache_diff < metric_thres[0]:
new_rate = cache_rates[0]
elif cache_diff < metric_thres[1]:
new_rate = cache_rates[1]
elif cache_diff < metric_thres[2]:
new_rate = cache_rates[2]
elif cache_diff < metric_thres[3]:
new_rate = cache_rates[3]
elif cache_diff < metric_thres[4]:
new_rate = cache_rates[4]
else:
new_rate = cache_rates[-1]
self.previous_residual_tiny = self.now_residual_tiny
return new_rate
self.scheduler.current["module"] = "txt_attn"
def clear(self):
if self.previous_residual is not None:
self.previous_residual = self.previous_residual.cpu()
if self.previous_residual_tiny is not None:
self.previous_residual_tiny = self.previous_residual_tiny.cpu()
if self.now_residual_tiny is not None:
self.now_residual_tiny = self.now_residual_tiny.cpu()
self.previous_residual = None
self.previous_residual_tiny = None
self.now_residual_tiny = None
torch.cuda.empty_cache()
self.scheduler.current["module"] = "txt_mlp"
out = out * txt_mod2_gate
txt = txt + out
class HunyuanTransformerInferCustomCaching(HunyuanTransformerInfer, BaseTaylorCachingTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = self.config.teacache_thresh
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
self.cache = {}
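# CustomCaching combines both ideas above: a TeaCache-style polynomial test
# (calculate_should_calc) decides when to recompute, while `self.cache` keeps a
# Taylor track of the whole-network residual so skipped steps extrapolate it
# instead of reusing it frozen.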
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if caching_records[index]:
img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
else:
img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if index <= self.scheduler.infer_steps - 2:
should_calc = self.calculate_should_calc(img, vec, weights)
self.scheduler.caching_records[index + 1] = should_calc
self.scheduler.current["module"] = "img_mlp"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
return img, vec
# 1. get the Taylor step_diff when the scheduler keeps only one caching_records list
def get_taylor_step_diff(self):
current_step = self.scheduler.step_index
last_calc_step = current_step - 1
while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
last_calc_step -= 1
step_diff = current_step - last_calc_step
return step_diff
# 1. only in tea-cache: judge whether the next step needs full computation
def calculate_should_calc(self, img, vec, weights):
# 1. timestep embedding
inp = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks[0].img_mod.apply(vec_).chunk(6, dim=-1)
normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
del normed_inp, inp, vec_
# 2. L1 calculate
if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
del modulated_inp
# 3. return the judgement
return should_calc
self.scheduler.current["module"] = "txt_mlp"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
ori_img = img.clone()
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights.double_blocks[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num
)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
self.previous_residual = img - ori_img
self.derivative_approximation(self.cache, "previous_residual", self.previous_residual)
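# tracking the residual's derivative lets infer_using_cache extrapolate the residual
# forward via taylor_formula rather than re-adding a stale constant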
if self.scheduler.current["type"] == "full":
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
self.scheduler.current["module"] = "attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q)
k = weights.k_norm.apply(k)
img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
if not self.parallel_attention:
attn = weights.single_attn.apply(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, attn)
self.scheduler.current["module"] = "total"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * mod_gate
x = x + out
return x
elif self.scheduler.current["type"] == "taylor_cache":
self.scheduler.current["module"] = "total"
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * mod_gate
x = x + out
return x
return img, vec
def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
img += self.taylor_formula(self.cache["previous_residual"])
return img, vec
def clear(self):
if self.previous_residual is not None:
self.previous_residual = self.previous_residual.cpu()
if self.previous_modulated_input is not None:
self.previous_modulated_input = self.previous_modulated_input.cpu()
self.previous_modulated_input = None
self.previous_residual = None
torch.cuda.empty_cache()
@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching,
)
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
@@ -156,10 +160,6 @@ class HunyuanModel:
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt == self.scheduler.num_steps:
self.scheduler.cnt = 0
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
@@ -170,5 +170,9 @@ class HunyuanModel:
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = HunyuanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = HunyuanTransformerInferCustomCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
for cache in self.blocks_cache_even:
for key in cache:
if cache[key] is not None:
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
for cache in self.blocks_cache_odd:
for key in cache:
if cache[key] is not None:
if isinstance(cache[key], torch.Tensor):
cache[key] = cache[key].cpu()
elif isinstance(cache[key], dict):
for k, v in cache[key].items():
if isinstance(v, torch.Tensor):
cache[key][k] = v.cpu()
cache.clear()
torch.cuda.empty_cache()
@@ -62,7 +62,7 @@ class WanModel:
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Taylor":
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
@@ -6,7 +6,7 @@ from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching, HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
scheduler = HunyuanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = HunyuanSchedulerAdaCaching(self.config)
elif self.config.feature_caching == "Custom":
scheduler = HunyuanSchedulerCustomCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
scheduler = WanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "Taylor":
elif self.config.feature_caching == "TaylorSeer":
scheduler = WanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = WanSchedulerAdaCaching(self.config)
from ..scheduler import HunyuanScheduler
import torch
@@ -6,31 +5,32 @@ import torch
class HunyuanSchedulerTeaCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
pattern = [True, False, False, False]
self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
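# full computation on every 4th step (True), Taylor extrapolation in between (False)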
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerAdaCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerCustomCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(
class HunyuanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps