Project: xuwx1/LightX2V
Commit 8689e4c7 (parent 0b755a97), authored Aug 15, 2025 by helloyongyang

    update tea cache

Showing 7 changed files with 25 additions and 33 deletions (+25 -33)
lightx2v/models/networks/wan/causvid_model.py                              +1   -1
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py            +2   -2
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py   +10  -16
lightx2v/models/networks/wan/infer/pre_infer.py                            +2   -2
lightx2v/models/networks/wan/infer/transformer_infer.py                    +0   -5
lightx2v/models/networks/wan/model.py                                      +9   -7
lightx2v/models/schedulers/scheduler.py                                    +1   -0
lightx2v/models/networks/wan/causvid_model.py

@@ -54,7 +54,7 @@ class WanCausVidModel(WanModel):
         self.pre_weight.to_cuda()
         self.post_weight.to_cuda()

-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end)
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
         self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
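The hunk above drops the `positive` keyword from `pre_infer.infer`; the same change recurs in `pre_wan_audio_infer.py` and `pre_infer.py` below, with `model.py` now writing the flag onto the scheduler instead. A standalone toy of that pattern (illustrative names mirroring the diff, not LightX2V code):

    class Scheduler:
        def __init__(self):
            # Single source of truth for the cfg pass:
            # True = conditional, False = unconditional.
            self.infer_condition = True

    class PreInfer:
        def __init__(self, scheduler):
            self.scheduler = scheduler

        def infer(self, inputs):  # note: no `positive` parameter anymore
            key = "context" if self.scheduler.infer_condition else "context_null"
            return inputs["text_encoder_output"][key]

    inputs = {"text_encoder_output": {"context": "cond", "context_null": "uncond"}}
    scheduler = Scheduler()
    pre = PreInfer(scheduler)

    scheduler.infer_condition = True
    assert pre.infer(inputs) == "cond"
    scheduler.infer_condition = False
    assert pre.infer(inputs) == "uncond"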
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -35,7 +35,7 @@ class WanAudioPreInfer(WanPreInfer):
         else:
             self.sp_size = 1

-    def infer(self, weights, inputs, positive):
+    def infer(self, weights, inputs):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
@@ -71,7 +71,7 @@ class WanAudioPreInfer(WanPreInfer):
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
         # audio_dit_blocks = None##Debug Drop Audio

-        if positive:
+        if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

@@ -24,7 +24,6 @@ class WanTransformerInferCaching(WanTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
         self.teacache_thresh = config.teacache_thresh
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
@@ -35,12 +34,12 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         self.use_ret_steps = config.use_ret_steps
         if self.use_ret_steps:
             self.coefficients = self.config.coefficients[0]
-            self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.ret_steps = 5
+            self.cutoff_steps = self.config.infer_steps
         else:
             self.coefficients = self.config.coefficients[1]
-            self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.ret_steps = 1
+            self.cutoff_steps = self.config.infer_steps - 1

     # calculate should_calc
     @torch.no_grad()
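A note the diff leaves implicit: the old `cnt` counter advanced once per forward pass, and with cfg enabled every denoising step runs two passes (conditional and unconditional), which is evidently why the warm-up and cutoff windows were expressed in passes (`5 * 2`, `infer_steps * 2`, `infer_steps * 2 - 2`). The new bookkeeping reads `scheduler.step_index`, which advances once per denoising step, so the same windows halve to `5`, `infer_steps`, and `infer_steps - 1`.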
@@ -50,8 +49,8 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         # 2. L1 calculate
         should_calc = False
-        if self.infer_conditional:
-            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+        if self.scheduler.infer_condition:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
                 should_calc = True
                 self.accumulated_rel_l1_distance_even = 0
             else:
@@ -67,7 +66,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
                     self.previous_e0_even = self.previous_e0_even.cpu()
         else:
-            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
                 should_calc = True
                 self.accumulated_rel_l1_distance_odd = 0
             else:
@@ -97,7 +96,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         return should_calc

     def infer_main_blocks(self, weights, pre_infer_out):
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             index = self.scheduler.step_index
             caching_records = self.scheduler.caching_records
             if index <= self.scheduler.infer_steps - 1:
@@ -121,11 +120,6 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(pre_infer_out.x)

-        if self.config.enable_cfg:
-            self.switch_status()
-
-        self.cnt += 1
-
         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
             torch.cuda.empty_cache()
@@ -136,7 +130,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         ori_x = pre_infer_out.x.clone()
         x = super().infer_main_blocks(weights, pre_infer_out)
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             self.previous_residual_even = x - ori_x
             if self.config["cpu_offload"]:
                 self.previous_residual_even = self.previous_residual_even.cpu()
@@ -153,7 +147,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         return x

     def infer_using_cache(self, x):
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             x.add_(self.previous_residual_even.cuda())
         else:
             x.add_(self.previous_residual_odd.cuda())
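For readers new to TeaCache: the `should_calc` logic this file updates follows the published TeaCache recipe, gating the full transformer on a polynomial-rescaled, accumulated relative-L1 change of the timestep-modulated input. A self-contained sketch under that assumption (names, coefficients, and values are illustrative, not this repo's exact code):

    import torch

    def should_recalculate(step_index, ret_steps, cutoff_steps, accumulated,
                           prev_e0, e0, coefficients, thresh):
        """Return (should_calc, new_accumulated)."""
        # Warm-up and tail steps always run the full transformer.
        if step_index < ret_steps or step_index >= cutoff_steps:
            return True, 0.0
        # Relative L1 change of the modulated input, rescaled by a fitted
        # polynomial (poly1d convention: highest power first).
        rel_l1 = ((e0 - prev_e0).abs().mean() / prev_e0.abs().mean()).item()
        accumulated += sum(c * rel_l1 ** (len(coefficients) - 1 - i)
                           for i, c in enumerate(coefficients))
        if accumulated < thresh:
            return False, accumulated   # cheap path: reuse the cached residual
        return True, 0.0                # expensive path: recompute and reset

    # Example: a small input change stays under the threshold, so the
    # cached residual is reused at this step.
    prev_e0, e0 = torch.ones(16), torch.ones(16) * 1.01
    calc, acc = should_recalculate(10, 5, 49, 0.0, prev_e0, e0,
                                   coefficients=[1.0, 0.0], thresh=0.26)
    print(calc, acc)   # False, ~0.01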
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -33,7 +33,7 @@ class WanPreInfer:
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
+    def infer(self, weights, inputs, kv_start=0, kv_end=0):
         x = self.scheduler.latents
         if self.scheduler.flag_df:
@@ -45,7 +45,7 @@ class WanPreInfer:
         if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
             t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()

-        if positive:
+        if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
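Worth noting from the context above: `infer` stays wrapped in `torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())`, so the signature change also applies to the compiled entry point. A minimal runnable illustration of that conditional-compile pattern, with a hypothetical stand-in for `CHECK_ENABLE_GRAPH_MODE`:

    import os
    import torch

    def CHECK_ENABLE_GRAPH_MODE() -> bool:
        # Hypothetical stand-in for LightX2V's helper: here, graph mode
        # is switched on via an environment variable.
        return os.environ.get("ENABLE_GRAPH_MODE", "0") == "1"

    # Compiled only when graph mode is enabled; plain eager execution otherwise.
    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def modulate(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return x * (1 + scale)

    print(modulate(torch.ones(4), torch.full((4,), 0.5)))  # tensor([1.5, 1.5, 1.5, 1.5])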
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             self.infer_func = self._infer_without_offload

-        self.infer_conditional = True
-
-    def switch_status(self):
-        self.infer_conditional = not self.infer_conditional
-
     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
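The deleted `switch_status` toggle is superseded by the scheduler-held flag. The surviving context also shows the varlen-attention length bookkeeping in `_calculate_q_k_len`; a standalone sketch of the same computation (dummy tensor, CPU-only):

    import torch

    # Dummy query of 128 packed tokens; only the length matters here.
    q = torch.randn(128, 8, 64)                    # (tokens, heads, head_dim)
    q_lens = torch.tensor([q.size(0)], dtype=torch.int32)

    # Prepend 0 and take the cumulative sum: cu_seqlens[i] is the start offset
    # of sequence i in the packed token dimension (varlen-attention convention).
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
    print(cu_seqlens_q)                            # tensor([  0, 128], dtype=torch.int32)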
lightx2v/models/networks/wan/model.py

@@ -329,9 +329,9 @@ class WanModel:
             cfg_p_rank = dist.get_rank(cfg_p_group)
             if cfg_p_rank == 0:
-                noise_pred = self._infer_cond_uncond(inputs, positive=True)
+                noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
             else:
-                noise_pred = self._infer_cond_uncond(inputs, positive=False)
+                noise_pred = self._infer_cond_uncond(inputs, infer_condition=False)

             noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
             dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
@@ -339,13 +339,13 @@ class WanModel:
             noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
         else:
             # ==================== CFG Processing ====================
-            noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
-            noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
+            noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
+            noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)

         self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
     else:
         # ==================== No CFG ====================
-        self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)
+        self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)

     if self.cpu_offload:
         if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
@@ -355,8 +355,10 @@ class WanModel:
             self.transformer_weights.post_weights_to_cpu()

     @torch.no_grad()
-    def _infer_cond_uncond(self, inputs, positive=True):
-        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
+    def _infer_cond_uncond(self, inputs, infer_condition=True):
+        self.scheduler.infer_condition = infer_condition
+
+        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)

         if self.config["seq_parallel"]:
             pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
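The guidance combine shown in context above is the standard classifier-free-guidance update. As a standalone illustration with dummy tensors (shapes and scale are arbitrary):

    import torch

    sample_guide_scale = 5.0                       # illustrative value
    noise_pred_cond = torch.randn(1, 16, 32, 32)   # dummy conditional prediction
    noise_pred_uncond = torch.randn(1, 16, 32, 32) # dummy unconditional prediction

    # uncond + scale * (cond - uncond): extrapolate from the unconditional
    # prediction toward (and past) the conditional one.
    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)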
lightx2v/models/schedulers/scheduler.py

@@ -10,6 +10,7 @@ class BaseScheduler:
         self.caching_records = [True] * config.infer_steps
         self.flag_df = False
         self.transformer_infer = None
+        self.infer_condition = True  # cfg status

     def step_pre(self, step_index):
         self.step_index = step_index