xuwx1 / LightX2V · Commits

Commit dcaefe63
Authored Jun 29, 2025 by Yang Yong(雍洋); committed by GitHub on Jun 29, 2025
Parent: bff9bd05

update feature caching (#78)

Co-authored-by: Linboyan-trc <1584340372@qq.com>
Showing 8 changed files with 840 additions and 137 deletions (+840 -137).
Files changed:

    lightx2v/common/transformer_infer/transformer_infer.py                    +47   -0
    lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py   +734  -92
    lightx2v/models/networks/wan/infer/transformer_infer.py                   +6    -3
    lightx2v/models/networks/wan/model.py                                     +9    -9
    lightx2v/models/runners/wan/wan_runner.py                                 +9    -0
    lightx2v/models/schedulers/scheduler.py                                   +3    -0
    lightx2v/models/schedulers/wan/feature_caching/scheduler.py               +30   -33
    lightx2v/models/schedulers/wan/scheduler.py                               +2    -0
lightx2v/common/transformer_infer/transformer_infer.py  (new file, mode 100644)

from abc import ABC, abstractmethod
import torch
import math


class BaseTransformerInfer(ABC):
    @abstractmethod
    def infer(self):
        pass

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.scheduler.transformer_infer = self


class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
    @abstractmethod
    def infer_calculating(self):
        pass

    @abstractmethod
    def infer_using_cache(self):
        pass

    @abstractmethod
    def get_taylor_step_diff(self):
        pass

    # 1. when fully calculated, store the output and its finite-difference derivative in the cache
    def derivative_approximation(self, block_cache, module_name, out):
        if module_name not in block_cache:
            block_cache[module_name] = {0: out}
        else:
            step_diff = self.get_taylor_step_diff()
            previous_out = block_cache[module_name][0]
            block_cache[module_name][0] = out
            block_cache[module_name][1] = (out - previous_out) / step_diff

    def taylor_formula(self, tensor_dict):
        x = self.get_taylor_step_diff()
        output = 0
        for i in range(len(tensor_dict)):
            output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
        return output
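The two concrete helpers above implement a first-order Taylor expansion around the last fully computed step: derivative_approximation stores a module's latest output together with a finite-difference estimate of its derivative, and taylor_formula extrapolates from that cache. A minimal, self-contained sketch of the behaviour, assuming the lightx2v package is importable; the _DemoInfer subclass and the constant step difference of 1.0 are hypothetical, used only for illustration:

    import torch
    from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer


    class _DemoInfer(BaseTaylorCachingTransformerInfer):
        """Hypothetical subclass used only to exercise the Taylor-cache helpers."""

        def infer(self):  # required by BaseTransformerInfer
            pass

        def infer_calculating(self):
            pass

        def infer_using_cache(self):
            pass

        def get_taylor_step_diff(self):
            # Assume a constant spacing of 1.0 between cached steps.
            return 1.0


    demo = _DemoInfer()
    cache = {}

    # Two fully computed outputs of some module, one step apart.
    demo.derivative_approximation(cache, "attn", torch.tensor([1.0, 2.0]))
    demo.derivative_approximation(cache, "attn", torch.tensor([2.0, 4.0]))

    # cache["attn"] now holds {0: latest_out, 1: finite-difference derivative}.
    # taylor_formula extrapolates one step further: out + derivative * step_diff.
    print(demo.taylor_formula(cache["attn"]))  # tensor([3., 6.])

With a first-order cache the prediction is simply out + derivative * step_diff, which is what the tensor([3., 6.]) result shows.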
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py  (+734 -92)

This diff is collapsed in the source view and is not shown here.
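This collapsed file carries most of the +734 added lines: the per-variant inference classes (WanTransformerInferTeaCaching, WanTransformerInferTaylorCaching, WanTransformerInferAdaCaching, WanTransformerInferCustomCaching) that model.py below now imports. Its exact contents are not visible on this page; as a rough, hypothetical sketch only, a TeaCache-style skip decision typically accumulates a polynomial-rescaled relative L1 distance between successive modulated inputs and only falls back to a full computation once it crosses the configured threshold:

    import numpy as np


    def should_recompute(e0, previous_e0, accumulated, coefficients, teacache_thresh):
        """Hypothetical TeaCache-style decision, not the code in this commit.

        e0 and previous_e0 are tensors (the current and previous modulated inputs).
        Reuse the cached residual while the accumulated, rescaled distance stays
        below the threshold; otherwise recompute and reset the accumulator.
        """
        rel_l1 = ((e0 - previous_e0).abs().mean() / previous_e0.abs().mean()).item()
        accumulated += np.polyval(coefficients, rel_l1)  # rescaling polynomial
        if accumulated < teacache_thresh:
            return False, accumulated   # skip: reuse cached residual
        return True, 0.0                # recompute and reset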
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -4,10 +4,11 @@ from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 )
+from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 
 
-class WanTransformerInfer:
+class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
         self.task = config["task"]
@@ -49,8 +50,10 @@ class WanTransformerInfer:
         else:
             self.infer_func = self._infer_without_offload
 
-    def set_scheduler(self, scheduler):
-        self.scheduler = scheduler
+        self.infer_conditional = True
+
+    def switch_status(self):
+        self.infer_conditional = not self.infer_conditional
 
     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
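The new infer_conditional flag and switch_status() replace the local set_scheduler (now inherited from BaseTransformerInfer) and let the caching logic tell the conditional pass apart from the unconditional CFG pass, which WanModel runs back to back. A hypothetical driver sketch, only to illustrate when the flag would flip; the real calls go through WanModel.infer with weights, grid sizes and embeddings:

    # Hypothetical sketch: one denoising step with classifier-free guidance.
    # run_cond_pass / run_uncond_pass stand in for the real transformer calls.
    def cfg_step(transformer_infer, run_cond_pass, run_uncond_pass):
        noise_pred_cond = run_cond_pass()       # infer_conditional is True here
        transformer_infer.switch_status()       # flip to the unconditional pass
        noise_pred_uncond = run_uncond_pass()   # infer_conditional is now False
        transformer_infer.switch_status()       # restore for the next step
        return noise_pred_cond, noise_pred_uncond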
lightx2v/models/networks/wan/model.py

@@ -15,6 +15,9 @@ from lightx2v.models.networks.wan.infer.transformer_infer import (
 )
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
     WanTransformerInferTeaCaching,
+    WanTransformerInferTaylorCaching,
+    WanTransformerInferAdaCaching,
+    WanTransformerInferCustomCaching,
 )
 from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -59,6 +62,12 @@ class WanModel:
             self.transformer_infer_class = WanTransformerInfer
         elif self.config["feature_caching"] == "Tea":
             self.transformer_infer_class = WanTransformerInferTeaCaching
+        elif self.config["feature_caching"] == "Taylor":
+            self.transformer_infer_class = WanTransformerInferTaylorCaching
+        elif self.config["feature_caching"] == "Ada":
+            self.transformer_infer_class = WanTransformerInferAdaCaching
+        elif self.config["feature_caching"] == "Custom":
+            self.transformer_infer_class = WanTransformerInferCustomCaching
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -201,10 +210,6 @@ class WanModel:
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
-        if self.config["feature_caching"] == "Tea":
-            self.scheduler.cnt += 1
-            if self.scheduler.cnt >= self.scheduler.num_steps:
-                self.scheduler.cnt = 0
         self.scheduler.noise_pred = noise_pred_cond
 
         if self.config["enable_cfg"]:
@@ -212,11 +217,6 @@ class WanModel:
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
-            if self.config["feature_caching"] == "Tea":
-                self.scheduler.cnt += 1
-                if self.scheduler.cnt >= self.scheduler.num_steps:
-                    self.scheduler.cnt = 0
-
             self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
 
         if self.config["cpu_offload"]:
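WanModel now picks the transformer-inference class purely from config["feature_caching"], and WanRunner below applies the same dispatch to the scheduler, so switching caching variants is a configuration change only. A hypothetical config fragment; only feature_caching is the key these diffs read, and the other entries are illustrative:

    # Hypothetical config fragment. Valid caching values after this commit are
    # "Tea", "Taylor", "Ada" and "Custom", plus whatever default value selects
    # the plain WanTransformerInfer / WanScheduler path; anything else raises
    # NotImplementedError in both WanModel and WanRunner. Note the diffs read
    # it both as config["feature_caching"] and as config.feature_caching, so
    # the real object is a dict-like config with attribute access; only the
    # values are shown here.
    config = {
        "feature_caching": "Taylor",
        "infer_steps": 50,      # illustrative
        "enable_cfg": True,     # illustrative
        "task": "t2v",          # illustrative
    }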
lightx2v/models/runners/wan/wan_runner.py

@@ -9,6 +9,9 @@ from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerTeaCaching,
+    WanSchedulerTaylorCaching,
+    WanSchedulerAdaCaching,
+    WanSchedulerCustomCaching,
 )
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -114,6 +117,12 @@ class WanRunner(DefaultRunner):
             scheduler = WanScheduler(self.config)
         elif self.config.feature_caching == "Tea":
             scheduler = WanSchedulerTeaCaching(self.config)
+        elif self.config.feature_caching == "Taylor":
+            scheduler = WanSchedulerTaylorCaching(self.config)
+        elif self.config.feature_caching == "Ada":
+            scheduler = WanSchedulerAdaCaching(self.config)
+        elif self.config.feature_caching == "Custom":
+            scheduler = WanSchedulerCustomCaching(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
lightx2v/models/schedulers/scheduler.py

@@ -7,7 +7,10 @@ class BaseScheduler:
         self.config = config
         self.step_index = 0
         self.latents = None
+        self.infer_steps = config.infer_steps
+        self.caching_records = [True] * config.infer_steps
         self.flag_df = False
+        self.transformer_infer = None
 
     def step_pre(self, step_index):
         self.step_index = step_index
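BaseScheduler now carries a transformer_infer back-reference (plus per-step caching_records), and BaseTransformerInfer.set_scheduler from the first file fills it in. That back-reference is what lets the feature-caching schedulers below delegate their clear() to the inference side, where the caches now live. A condensed wiring sketch, assuming both objects are built from the same config; the real wiring goes through self.model.set_scheduler(scheduler) in WanRunner:

    # Condensed sketch, not the real runner code: set_scheduler() links the two
    # objects in both directions.
    scheduler = WanSchedulerTeaCaching(config)
    transformer_infer = WanTransformerInferTeaCaching(config)

    transformer_infer.set_scheduler(scheduler)          # from BaseTransformerInfer
    assert scheduler.transformer_infer is transformer_infer

    scheduler.clear()   # now reaches the caches held by transformer_infer.clear()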
lightx2v/models/schedulers/wan/feature_caching/scheduler.py

-import torch
-from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from ..scheduler import WanScheduler
 
 
 class WanSchedulerTeaCaching(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
-        self.num_steps = self.config.infer_steps * 2
-        self.teacache_thresh = self.config.teacache_thresh
-        self.accumulated_rel_l1_distance_even = 0
-        self.accumulated_rel_l1_distance_odd = 0
-        self.previous_e0_even = None
-        self.previous_e0_odd = None
-        self.previous_residual_even = None
-        self.previous_residual_odd = None
-        self.use_ret_steps = self.config.use_ret_steps
-        if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
-            self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
-        else:
-            self.coefficients = self.config.coefficients[1]
-            self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
 
     def clear(self):
-        if self.previous_e0_even is not None:
-            self.previous_e0_even = self.previous_e0_even.cpu()
-        if self.previous_e0_odd is not None:
-            self.previous_e0_odd = self.previous_e0_odd.cpu()
-        if self.previous_residual_even is not None:
-            self.previous_residual_even = self.previous_residual_even.cpu()
-        if self.previous_residual_odd is not None:
-            self.previous_residual_odd = self.previous_residual_odd.cpu()
-        self.previous_e0_even = None
-        self.previous_e0_odd = None
-        self.previous_residual_even = None
-        self.previous_residual_odd = None
-        torch.cuda.empty_cache()
+        self.transformer_infer.clear()
+
+
+class WanSchedulerTaylorCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+        pattern = [True, False, False, False]
+        self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
+        self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
+
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class WanSchedulerAdaCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
+
+
+class WanSchedulerCustomCaching(WanScheduler):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def clear(self):
+        self.transformer_infer.clear()
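WanSchedulerTaylorCaching derives its caching_records from a repeating [True, False, False, False] pattern: one full computation every fourth step, with the intervening steps served from the Taylor cache. For example, with infer_steps = 10:

    pattern = [True, False, False, False]
    infer_steps = 10  # example value
    records = (pattern * ((infer_steps + 3) // 4))[:infer_steps]
    # -> [True, False, False, False, True, False, False, False, True, False]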
lightx2v/models/schedulers/wan/scheduler.py

@@ -18,6 +18,8 @@ class WanScheduler(BaseScheduler):
         self.solver_order = 2
         self.noise_pred = None
+        self.caching_records_2 = [True] * self.config.infer_steps
 
     def prepare(self, image_encoder_output=None):
         self.generator = torch.Generator(device=self.device)
         self.generator.manual_seed(self.config.seed)