xuwx1 / LightX2V · Commits · 220a631f

Commit 220a631f, authored Jun 30, 2025 by Yang Yong(雍洋), committed by GitHub on Jun 30, 2025
update hunyuan cache (#79)
Co-authored-by: Linboyan-trc <1584340372@qq.com>
parent 9da774a7
Showing 8 changed files with 574 additions and 277 deletions (+574 -277):
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py  (+524 -244)
lightx2v/models/networks/hunyuan/model.py  (+10 -6)
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py  (+12 -2)
lightx2v/models/networks/wan/model.py  (+1 -1)
lightx2v/models/runners/hunyuan/hunyuan_runner.py  (+5 -1)
lightx2v/models/runners/wan/wan_runner.py  (+1 -1)
lightx2v/models/schedulers/hunyuan/feature_caching/scheduler.py  (+21 -21)
lightx2v/models/schedulers/hunyuan/scheduler.py  (+0 -1)
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py (view file @ 220a631f)
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
import torch
import numpy as np
from einops import rearrange
from .utils import taylor_cache_init, derivative_approximation, taylor_formula
from ..utils_bf16 import apply_rotary_emb
from ..transformer_infer import HunyuanTransformerInfer
class HunyuanTransformerInferTeaCaching(HunyuanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.teacache_thresh = self.config.teacache_thresh
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

    # 1. only in tea-cache, judge next step
    def calculate_should_calc(self, img, vec, weights):
        # 1. timestep embedding
        inp = img.clone()
        vec_ = vec.clone()
        img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks[0].img_mod.apply(vec_).chunk(6, dim=-1)
        normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
        modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        del normed_inp, inp, vec_

        # 2. L1 calculate
        if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.teacache_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        del modulated_inp

        # 3. return the judgement
        return should_calc

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        index = self.scheduler.step_index
        caching_records = self.scheduler.caching_records
        if caching_records[index]:
            img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        else:
            img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        if index <= self.scheduler.infer_steps - 2:
            should_calc = self.calculate_should_calc(img, vec, weights)
            self.scheduler.caching_records[index + 1] = should_calc
        return img, vec

    def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        # 1. copy the noise
        ori_img = img.clone()
        # 2. fully calculate
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            (img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
             tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
             txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.infer_double_block_phase_1(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num,
                img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
                tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
                txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
            img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
                weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        # 3. cache the residual
        self.previous_residual = img - ori_img
        return img, vec

    def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        img += self.previous_residual
        return img, vec

    def clear(self):
        if self.previous_residual is not None:
            self.previous_residual = self.previous_residual.cpu()
        if self.previous_modulated_input is not None:
            self.previous_modulated_input = self.previous_modulated_input.cpu()
        self.previous_modulated_input = None
        self.previous_residual = None
        torch.cuda.empty_cache()
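TeaCache's skip rule above is self-contained enough to exercise on synthetic data: it rescales the relative L1 change of the modulated input with a fixed polynomial, accumulates it, and only recomputes once the accumulator crosses the threshold. A minimal sketch, assuming an illustrative threshold of 0.15 and random vectors in place of real modulated features:

import numpy as np

coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
teacache_thresh = 0.15  # assumed value for illustration; in the repo it comes from config.teacache_thresh

accumulated = 0.0
previous = None
for step, modulated in enumerate(np.random.rand(10, 64)):
    if previous is None:
        should_calc = True          # the first step is always fully computed
        accumulated = 0.0
    else:
        rel_l1 = np.abs(modulated - previous).mean() / np.abs(previous).mean()
        accumulated += rescale_func(rel_l1)
        should_calc = accumulated >= teacache_thresh
        if should_calc:
            accumulated = 0.0       # reset after a full computation
    previous = modulated
    print(step, should_calc)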
class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer, BaseTaylorCachingTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        assert not self.config["cpu_offload"], "Not support cpu-offload for TaylorCaching"
        self.double_blocks_cache = [{} for _ in range(self.double_blocks_num)]
        self.single_blocks_cache = [{} for _ in range(self.single_blocks_num)]

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        index = self.scheduler.step_index
        caching_records = self.scheduler.caching_records
        if caching_records[index]:
            return self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        else:
            return self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

    # 1. get taylor step_diff when there is only one caching_records in scheduler
    def get_taylor_step_diff(self):
        current_step = self.scheduler.step_index
        last_calc_step = current_step - 1
        while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
            last_calc_step -= 1
        step_diff = current_step - last_calc_step
        return step_diff

    def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            (img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
             tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
             txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.infer_double_block_phase_1(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            self.derivative_approximation(self.double_blocks_cache[i], "img_attn", img_out)
            self.derivative_approximation(self.double_blocks_cache[i], "txt_attn", txt_out)
            img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num,
                img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
                tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
                txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
            self.derivative_approximation(self.double_blocks_cache[i], "img_mlp", img_out)
            self.derivative_approximation(self.double_blocks_cache[i], "txt_mlp", txt_out)
            img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
                weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            self.derivative_approximation(self.single_blocks_cache[i], "total", out)
            x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        return img, vec

    def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i)
        img = x[:img_seq_len, ...]
        return img, vec

    # 1. taylor using caching
    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i):
        vec_silu = torch.nn.functional.silu(vec)
        img_mod_out = weights.img_mod.apply(vec_silu)
        img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod_out.chunk(6, dim=-1)
        txt_mod_out = weights.txt_mod.apply(vec_silu)
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod_out.chunk(6, dim=-1)
        out = self.taylor_formula(self.double_blocks_cache[i]["img_attn"])
        out = out * img_mod1_gate
        img = img + out
        out = self.taylor_formula(self.double_blocks_cache[i]["img_mlp"])
        out = out * img_mod2_gate
        img = img + out
        out = self.taylor_formula(self.double_blocks_cache[i]["txt_attn"])
        out = out * txt_mod1_gate
        txt = txt + out
        out = self.taylor_formula(self.double_blocks_cache[i]["txt_mlp"])
        out = out * txt_mod2_gate
        txt = txt + out
        return img, txt

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, i):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
        out = self.taylor_formula(self.single_blocks_cache[i]["total"])
        out = out * mod_gate
        x = x + out
        return x

    def clear(self):
        for cache in self.double_blocks_cache:
            for key in cache:
                if cache[key] is not None:
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = cache[key].cpu()
                    elif isinstance(cache[key], dict):
                        for k, v in cache[key].items():
                            if isinstance(v, torch.Tensor):
                                cache[key][k] = v.cpu()
            cache.clear()
        for cache in self.single_blocks_cache:
            for key in cache:
                if cache[key] is not None:
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = cache[key].cpu()
                    elif isinstance(cache[key], dict):
                        for k, v in cache[key].items():
                            if isinstance(v, torch.Tensor):
                                cache[key][k] = v.cpu()
            cache.clear()
        torch.cuda.empty_cache()
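The derivative_approximation / taylor_formula pair used above is inherited from BaseTaylorCachingTransformerInfer and is not part of this diff. The underlying idea (TaylorSeer-style forecasting) is to store each block output together with a finite-difference derivative at the last fully computed step, then extrapolate it forward on skipped steps. A hedged first-order sketch; the real helpers may keep higher-order terms and use a different cache layout:

import torch

def derivative_approximation(cache: dict, key: str, value: torch.Tensor, step_diff: int = 1) -> None:
    # Store the feature plus a finite-difference estimate of its per-step change.
    prev = cache.get(key)
    if prev is None:
        cache[key] = {"value": value, "derivative": torch.zeros_like(value)}
    else:
        cache[key] = {"value": value, "derivative": (value - prev["value"]) / step_diff}

def taylor_formula(entry: dict, step_diff: int = 1) -> torch.Tensor:
    # First-order Taylor expansion around the last computed step.
    return entry["value"] + entry["derivative"] * step_diff

In the class above, get_taylor_step_diff supplies the distance back to the last fully computed step, so a skip landing several steps later is extrapolated proportionally.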
class HunyuanTransformerInferAdaCaching(HunyuanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        # 1. fixed args
        self.decisive_double_block_id = 10
        self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3}
        # 2. cache
        self.previous_residual_tiny = None
        self.now_residual_tiny = None
        self.norm_ord = 1
        self.skipped_step_length = 1
        self.previous_residual = None
        # 3. moreg
        self.previous_moreg = 1.0
        self.moreg_strides = [1]
        self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
        self.moreg_hyp = [0.385, 8, 1, 2]
        self.mograd_mul = 10
        self.spatial_dim = 3072

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        index = self.scheduler.step_index
        caching_records = self.scheduler.caching_records
        if caching_records[index]:
            img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            # 3. calculate the skipped step length
            if index <= self.scheduler.infer_steps - 2:
                self.skipped_step_length = self.calculate_skip_step_length()
                for i in range(1, self.skipped_step_length):
                    if (index + i) <= self.scheduler.infer_steps - 1:
                        self.scheduler.caching_records[index + i] = False
        else:
            img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        return img, vec

    def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        ori_img = img.clone()
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            (img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
             tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
             txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.infer_double_block_phase_1(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num,
                img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
                tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
                txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
            if i == self.decisive_double_block_id:
                self.now_residual_tiny = img_out * img_mod2_gate
            img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
                weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        self.previous_residual = img - ori_img
        return img, vec

    def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        img += self.previous_residual
        return img, vec

    # 1. ada's algorithm to calculate skip step length
    def calculate_skip_step_length(self):
        if self.previous_residual_tiny is None:
            self.previous_residual_tiny = self.now_residual_tiny
            return 1
        else:
            cache = self.previous_residual_tiny
            res = self.now_residual_tiny
            norm_ord = self.norm_ord
            cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
            cache_diff = cache_diff / self.skipped_step_length
            if self.moreg_steps[0] <= self.scheduler.step_index <= self.moreg_steps[1]:
                moreg = 0
                for i in self.moreg_strides:
                    moreg_i = (res[i * self.spatial_dim :, :] - res[: -i * self.spatial_dim, :]).norm(p=norm_ord)
                    moreg_i /= res[i * self.spatial_dim :, :].norm(p=norm_ord) + res[: -i * self.spatial_dim, :].norm(p=norm_ord)
                    moreg += moreg_i
                moreg = moreg / len(self.moreg_strides)
                moreg = ((1 / self.moreg_hyp[0] * moreg) ** self.moreg_hyp[1]) / self.moreg_hyp[2]
            else:
                moreg = 1.0
            mograd = self.mograd_mul * (moreg - self.previous_moreg) / self.skipped_step_length
            self.previous_moreg = moreg
            moreg = moreg + abs(mograd)
            cache_diff = cache_diff * moreg
            metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
            if cache_diff < metric_thres[0]:
                new_rate = cache_rates[0]
            elif cache_diff < metric_thres[1]:
                new_rate = cache_rates[1]
            elif cache_diff < metric_thres[2]:
                new_rate = cache_rates[2]
            elif cache_diff < metric_thres[3]:
                new_rate = cache_rates[3]
            elif cache_diff < metric_thres[4]:
                new_rate = cache_rates[4]
            else:
                new_rate = cache_rates[-1]
            self.previous_residual_tiny = self.now_residual_tiny
            return new_rate

    def clear(self):
        if self.previous_residual is not None:
            self.previous_residual = self.previous_residual.cpu()
        if self.previous_residual_tiny is not None:
            self.previous_residual_tiny = self.previous_residual_tiny.cpu()
        if self.now_residual_tiny is not None:
            self.now_residual_tiny = self.now_residual_tiny.cpu()
        self.previous_residual = None
        self.previous_residual_tiny = None
        self.now_residual_tiny = None
        torch.cuda.empty_cache()
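AdaCaching sizes the next skip from how fast a cheap probe changes: the gated output of decisive double block 10 (the "tiny" residual) is compared against its previous value, the relative change is modulated by the motion regularizer, and the result is quantized through the codebook into a skip length. The lookup in isolation:

codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3}

def lookup_skip_length(cache_diff: float) -> int:
    # Smaller feature drift -> longer skip; dicts preserve insertion order in Python 3.7+.
    for thresh, rate in codebook.items():
        if cache_diff < thresh:
            return rate
    return list(codebook.values())[-1]

assert lookup_skip_length(0.01) == 12  # very stable features: skip far ahead
assert lookup_skip_length(0.10) == 4   # fast-changing features: recompute sooner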
class HunyuanTransformerInferCustomCaching(HunyuanTransformerInfer, BaseTaylorCachingTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.teacache_thresh = self.config.teacache_thresh
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
        self.cache = {}

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        index = self.scheduler.step_index
        caching_records = self.scheduler.caching_records
        if caching_records[index]:
            img, vec = self.infer_calculating(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        else:
            img, vec = self.infer_using_cache(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        if index <= self.scheduler.infer_steps - 2:
            should_calc = self.calculate_should_calc(img, vec, weights)
            self.scheduler.caching_records[index + 1] = should_calc
        return img, vec

    # 1. get taylor step_diff when there is only one caching_records in scheduler
    def get_taylor_step_diff(self):
        current_step = self.scheduler.step_index
        last_calc_step = current_step - 1
        while last_calc_step >= 0 and not self.scheduler.caching_records[last_calc_step]:
            last_calc_step -= 1
        step_diff = current_step - last_calc_step
        return step_diff

    # 1. only in tea-cache, judge next step
    def calculate_should_calc(self, img, vec, weights):
        # 1. timestep embedding
        inp = img.clone()
        vec_ = vec.clone()
        img_mod1_shift, img_mod1_scale, _, _, _, _ = weights.double_blocks[0].img_mod.apply(vec_).chunk(6, dim=-1)
        normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
        modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        del normed_inp, inp, vec_

        # 2. L1 calculate
        if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.teacache_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        del modulated_inp

        # 3. return the judgement
        return should_calc

    def infer_calculating(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        ori_img = img.clone()
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            (img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
             tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
             txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.infer_double_block_phase_1(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
                weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num,
                img_out, txt_out, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate,
                tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate,
                txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
            img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(
                weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        self.previous_residual = img - ori_img
        self.derivative_approximation(self.cache, "previous_residual", self.previous_residual)
        return img, vec

    def infer_using_cache(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        img += self.taylor_formula(self.cache["previous_residual"])
        return img, vec

    def clear(self):
        if self.previous_residual is not None:
            self.previous_residual = self.previous_residual.cpu()
        if self.previous_modulated_input is not None:
            self.previous_modulated_input = self.previous_modulated_input.cpu()
        self.previous_modulated_input = None
        self.previous_residual = None
        torch.cuda.empty_cache()
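CustomCaching combines the two mechanisms: the TeaCache-style judgement decides when to recompute, while skipped steps apply a Taylor-extrapolated whole-network residual instead of replaying a frozen one. A self-contained toy of that skip path, with the first-order extrapolation written out inline (tensor shapes and the cache layout are assumptions, not the repo's API):

import torch

cache = {}

def store_residual(residual: torch.Tensor, step_diff: int = 1) -> None:
    # Cache the residual and a finite-difference estimate of its per-step change.
    prev = cache.get("previous_residual")
    deriv = torch.zeros_like(residual) if prev is None else (residual - prev["value"]) / step_diff
    cache["previous_residual"] = {"value": residual, "derivative": deriv}

img = torch.randn(16, 3072)
store_residual(torch.randn_like(img))                  # computed step: cache the residual
entry = cache["previous_residual"]
img = img + entry["value"] + entry["derivative"] * 1   # skipped step: add the extrapolated residual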
lightx2v/models/networks/hunyuan/model.py (view file @ 220a631f)

@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
 from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
 from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
 from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
-from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
+from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
+    HunyuanTransformerInferTaylorCaching,
+    HunyuanTransformerInferTeaCaching,
+    HunyuanTransformerInferAdaCaching,
+    HunyuanTransformerInferCustomCaching,
+)
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *

@@ -156,10 +160,6 @@ class HunyuanModel:
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
-        if self.config["feature_caching"] == "Tea":
-            self.scheduler.cnt += 1
-            if self.scheduler.cnt == self.scheduler.num_steps:
-                self.scheduler.cnt = 0

     def _init_infer_class(self):
         self.pre_infer_class = HunyuanPreInfer

@@ -170,5 +170,9 @@ class HunyuanModel:
             self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = HunyuanTransformerInferTeaCaching
+        elif self.config["feature_caching"] == "Ada":
+            self.transformer_infer_class = HunyuanTransformerInferAdaCaching
+        elif self.config["feature_caching"] == "Custom":
+            self.transformer_infer_class = HunyuanTransformerInferCustomCaching
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
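With the dispatch above, the caching variant is chosen per run from the config. A hypothetical config fragment (key names taken from the code in this commit; the values are illustrative):

config = {
    "feature_caching": "Ada",   # "Tea" / "TaylorSeer" / "Ada" / "Custom"
    "cpu_offload": False,       # TaylorCaching asserts cpu_offload is disabled
    "infer_steps": 50,
    "teacache_thresh": 0.15,    # used by the Tea and Custom variants; value assumed
}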
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py (view file @ 220a631f)

@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         for cache in self.blocks_cache_even:
             for key in cache:
                 if cache[key] is not None:
-                    cache[key] = cache[key].cpu()
+                    if isinstance(cache[key], torch.Tensor):
+                        cache[key] = cache[key].cpu()
+                    elif isinstance(cache[key], dict):
+                        for k, v in cache[key].items():
+                            if isinstance(v, torch.Tensor):
+                                cache[key][k] = v.cpu()
             cache.clear()
         for cache in self.blocks_cache_odd:
             for key in cache:
                 if cache[key] is not None:
-                    cache[key] = cache[key].cpu()
+                    if isinstance(cache[key], torch.Tensor):
+                        cache[key] = cache[key].cpu()
+                    elif isinstance(cache[key], dict):
+                        for k, v in cache[key].items():
+                            if isinstance(v, torch.Tensor):
+                                cache[key][k] = v.cpu()
             cache.clear()
         torch.cuda.empty_cache()
lightx2v/models/networks/wan/model.py (view file @ 220a631f)

@@ -62,7 +62,7 @@ class WanModel:
             self.transformer_infer_class = WanTransformerInfer
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = WanTransformerInferTeaCaching
-        elif self.config["feature_caching"] == "Taylor":
+        elif self.config["feature_caching"] == "TaylorSeer":
             self.transformer_infer_class = WanTransformerInferTaylorCaching
         elif self.config["feature_caching"] == "Ada":
             self.transformer_infer_class = WanTransformerInferAdaCaching
lightx2v/models/runners/hunyuan/hunyuan_runner.py (view file @ 220a631f)

@@ -6,7 +6,7 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
-from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
+from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching, HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching
 from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel

@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
             scheduler = HunyuanSchedulerTeaCaching(self.config)
         elif self.config.feature_caching == "TaylorSeer":
             scheduler = HunyuanSchedulerTaylorCaching(self.config)
+        elif self.config.feature_caching == "Ada":
+            scheduler = HunyuanSchedulerAdaCaching(self.config)
+        elif self.config.feature_caching == "Custom":
+            scheduler = HunyuanSchedulerCustomCaching(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
lightx2v/models/runners/wan/wan_runner.py (view file @ 220a631f)

@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
             scheduler = WanScheduler(self.config)
         elif self.config.feature_caching == "Tea":
             scheduler = WanSchedulerTeaCaching(self.config)
-        elif self.config.feature_caching == "Taylor":
+        elif self.config.feature_caching == "TaylorSeer":
             scheduler = WanSchedulerTaylorCaching(self.config)
         elif self.config.feature_caching == "Ada":
             scheduler = WanSchedulerAdaCaching(self.config)
lightx2v/models/schedulers/hunyuan/feature_caching/scheduler.py (view file @ 220a631f)

-from .utils import cache_init, cal_type
 from ..scheduler import HunyuanScheduler
 import torch

@@ -6,31 +5,32 @@ import torch
 class HunyuanSchedulerTeaCaching(HunyuanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
-        self.num_steps = self.config.infer_steps
-        self.teacache_thresh = self.config.teacache_thresh
-        self.accumulated_rel_l1_distance = 0
-        self.previous_modulated_input = None
-        self.previous_residual = None
-        self.coefficients = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]

     def clear(self):
-        if self.previous_residual is not None:
-            self.previous_residual = self.previous_residual.cpu()
-        if self.previous_modulated_input is not None:
-            self.previous_modulated_input = self.previous_modulated_input.cpu()
-        self.previous_modulated_input = None
-        self.previous_residual = None
-        torch.cuda.empty_cache()
+        self.transformer_infer.clear()


 class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cache_dic, self.current = cache_init(self.infer_steps)
+        pattern = [True, False, False, False]
+        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]

     def clear(self):
+        self.transformer_infer.clear()


+class HunyuanSchedulerAdaCaching(HunyuanScheduler):
+    def __init__(self, config):
+        super().__init__(config)

-    def step_pre(self, step_index):
-        super().step_pre(step_index)
-        self.current["step"] = step_index
-        cal_type(self.cache_dic, self.current)

+    def clear(self):
+        self.transformer_infer.clear()


+class HunyuanSchedulerCustomCaching(HunyuanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
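The TaylorCaching scheduler now derives its skip schedule from a fixed pattern instead of cache_init: every fourth step is fully computed and the three steps in between are forecast. For example, with infer_steps = 10:

infer_steps = 10
pattern = [True, False, False, False]
caching_records = (pattern * ((infer_steps + 3) // 4))[:infer_steps]
print(caching_records)
# [True, False, False, False, True, False, False, False, True, False]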
lightx2v/models/schedulers/hunyuan/scheduler.py (view file @ 220a631f)

@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(
 class HunyuanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.infer_steps = self.config.infer_steps
         self.shift = 7.0
         self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
         assert len(self.timesteps) == self.infer_steps