Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
220a631f
Commit
220a631f
authored
Jun 30, 2025
by
Yang Yong(雍洋)
Committed by
GitHub
Jun 30, 2025
Browse files
update hunyuan cache (#79)
Co-authored-by:
Linboyan-trc
<
1584340372@qq.com
>
parent
9da774a7
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
574 additions
and
277 deletions
+574
-277
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py
...tworks/hunyuan/infer/feature_caching/transformer_infer.py
+524
-244
lightx2v/models/networks/hunyuan/model.py
lightx2v/models/networks/hunyuan/model.py
+10
-6
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
...s/networks/wan/infer/feature_caching/transformer_infer.py
+12
-2
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+1
-1
lightx2v/models/runners/hunyuan/hunyuan_runner.py
lightx2v/models/runners/hunyuan/hunyuan_runner.py
+5
-1
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+1
-1
lightx2v/models/schedulers/hunyuan/feature_caching/scheduler.py
...2v/models/schedulers/hunyuan/feature_caching/scheduler.py
+21
-21
lightx2v/models/schedulers/hunyuan/scheduler.py
lightx2v/models/schedulers/hunyuan/scheduler.py
+0
-1
No files found.
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py
View file @
220a631f
This diff is collapsed.
Click to expand it.
lightx2v/models/networks/hunyuan/model.py
View file @
220a631f
...
@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
...
@@ -7,8 +7,12 @@ from lightx2v.models.networks.hunyuan.weights.transformer_weights import Hunyuan
from
lightx2v.models.networks.hunyuan.infer.pre_infer
import
HunyuanPreInfer
from
lightx2v.models.networks.hunyuan.infer.pre_infer
import
HunyuanPreInfer
from
lightx2v.models.networks.hunyuan.infer.post_infer
import
HunyuanPostInfer
from
lightx2v.models.networks.hunyuan.infer.post_infer
import
HunyuanPostInfer
from
lightx2v.models.networks.hunyuan.infer.transformer_infer
import
HunyuanTransformerInfer
from
lightx2v.models.networks.hunyuan.infer.transformer_infer
import
HunyuanTransformerInfer
from
lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer
import
HunyuanTransformerInferTaylorCaching
,
HunyuanTransformerInferTeaCaching
from
lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer
import
(
HunyuanTransformerInferTaylorCaching
,
HunyuanTransformerInferTeaCaching
,
HunyuanTransformerInferAdaCaching
,
HunyuanTransformerInferCustomCaching
,
)
import
lightx2v.attentions.distributed.ulysses.wrap
as
ulysses_dist_wrap
import
lightx2v.attentions.distributed.ulysses.wrap
as
ulysses_dist_wrap
import
lightx2v.attentions.distributed.ring.wrap
as
ring_dist_wrap
import
lightx2v.attentions.distributed.ring.wrap
as
ring_dist_wrap
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
...
@@ -156,10 +160,6 @@ class HunyuanModel:
...
@@ -156,10 +160,6 @@ class HunyuanModel:
if
self
.
config
[
"cpu_offload"
]:
if
self
.
config
[
"cpu_offload"
]:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
self
.
scheduler
.
cnt
+=
1
if
self
.
scheduler
.
cnt
==
self
.
scheduler
.
num_steps
:
self
.
scheduler
.
cnt
=
0
def
_init_infer_class
(
self
):
def
_init_infer_class
(
self
):
self
.
pre_infer_class
=
HunyuanPreInfer
self
.
pre_infer_class
=
HunyuanPreInfer
...
@@ -170,5 +170,9 @@ class HunyuanModel:
...
@@ -170,5 +170,9 @@ class HunyuanModel:
self
.
transformer_infer_class
=
HunyuanTransformerInferTaylorCaching
self
.
transformer_infer_class
=
HunyuanTransformerInferTaylorCaching
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
self
.
transformer_infer_class
=
HunyuanTransformerInferTeaCaching
self
.
transformer_infer_class
=
HunyuanTransformerInferTeaCaching
elif
self
.
config
[
"feature_caching"
]
==
"Ada"
:
self
.
transformer_infer_class
=
HunyuanTransformerInferAdaCaching
elif
self
.
config
[
"feature_caching"
]
==
"Custom"
:
self
.
transformer_infer_class
=
HunyuanTransformerInferCustomCaching
else
:
else
:
raise
NotImplementedError
(
f
"Unsupported feature_caching type:
{
self
.
config
[
'feature_caching'
]
}
"
)
raise
NotImplementedError
(
f
"Unsupported feature_caching type:
{
self
.
config
[
'feature_caching'
]
}
"
)
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
View file @
220a631f
...
@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
...
@@ -305,13 +305,23 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
for
cache
in
self
.
blocks_cache_even
:
for
cache
in
self
.
blocks_cache_even
:
for
key
in
cache
:
for
key
in
cache
:
if
cache
[
key
]
is
not
None
:
if
cache
[
key
]
is
not
None
:
cache
[
key
]
=
cache
[
key
].
cpu
()
if
isinstance
(
cache
[
key
],
torch
.
Tensor
):
cache
[
key
]
=
cache
[
key
].
cpu
()
elif
isinstance
(
cache
[
key
],
dict
):
for
k
,
v
in
cache
[
key
].
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
cache
[
key
][
k
]
=
v
.
cpu
()
cache
.
clear
()
cache
.
clear
()
for
cache
in
self
.
blocks_cache_odd
:
for
cache
in
self
.
blocks_cache_odd
:
for
key
in
cache
:
for
key
in
cache
:
if
cache
[
key
]
is
not
None
:
if
cache
[
key
]
is
not
None
:
cache
[
key
]
=
cache
[
key
].
cpu
()
if
isinstance
(
cache
[
key
],
torch
.
Tensor
):
cache
[
key
]
=
cache
[
key
].
cpu
()
elif
isinstance
(
cache
[
key
],
dict
):
for
k
,
v
in
cache
[
key
].
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
cache
[
key
][
k
]
=
v
.
cpu
()
cache
.
clear
()
cache
.
clear
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
lightx2v/models/networks/wan/model.py
View file @
220a631f
...
@@ -62,7 +62,7 @@ class WanModel:
...
@@ -62,7 +62,7 @@ class WanModel:
self
.
transformer_infer_class
=
WanTransformerInfer
self
.
transformer_infer_class
=
WanTransformerInfer
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
self
.
transformer_infer_class
=
WanTransformerInferTeaCaching
self
.
transformer_infer_class
=
WanTransformerInferTeaCaching
elif
self
.
config
[
"feature_caching"
]
==
"Taylor"
:
elif
self
.
config
[
"feature_caching"
]
==
"Taylor
Seer
"
:
self
.
transformer_infer_class
=
WanTransformerInferTaylorCaching
self
.
transformer_infer_class
=
WanTransformerInferTaylorCaching
elif
self
.
config
[
"feature_caching"
]
==
"Ada"
:
elif
self
.
config
[
"feature_caching"
]
==
"Ada"
:
self
.
transformer_infer_class
=
WanTransformerInferAdaCaching
self
.
transformer_infer_class
=
WanTransformerInferAdaCaching
...
...
lightx2v/models/runners/hunyuan/hunyuan_runner.py
View file @
220a631f
...
@@ -6,7 +6,7 @@ from PIL import Image
...
@@ -6,7 +6,7 @@ from PIL import Image
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.models.runners.default_runner
import
DefaultRunner
from
lightx2v.models.runners.default_runner
import
DefaultRunner
from
lightx2v.models.schedulers.hunyuan.scheduler
import
HunyuanScheduler
from
lightx2v.models.schedulers.hunyuan.scheduler
import
HunyuanScheduler
from
lightx2v.models.schedulers.hunyuan.feature_caching.scheduler
import
HunyuanSchedulerTaylorCaching
,
HunyuanSchedulerTeaCaching
from
lightx2v.models.schedulers.hunyuan.feature_caching.scheduler
import
HunyuanSchedulerTaylorCaching
,
HunyuanSchedulerTeaCaching
,
HunyuanSchedulerAdaCaching
,
HunyuanSchedulerCustomCaching
from
lightx2v.models.input_encoders.hf.llama.model
import
TextEncoderHFLlamaModel
from
lightx2v.models.input_encoders.hf.llama.model
import
TextEncoderHFLlamaModel
from
lightx2v.models.input_encoders.hf.clip.model
import
TextEncoderHFClipModel
from
lightx2v.models.input_encoders.hf.clip.model
import
TextEncoderHFClipModel
from
lightx2v.models.input_encoders.hf.llava.model
import
TextEncoderHFLlavaModel
from
lightx2v.models.input_encoders.hf.llava.model
import
TextEncoderHFLlavaModel
...
@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
...
@@ -47,6 +47,10 @@ class HunyuanRunner(DefaultRunner):
scheduler
=
HunyuanSchedulerTeaCaching
(
self
.
config
)
scheduler
=
HunyuanSchedulerTeaCaching
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"TaylorSeer"
:
elif
self
.
config
.
feature_caching
==
"TaylorSeer"
:
scheduler
=
HunyuanSchedulerTaylorCaching
(
self
.
config
)
scheduler
=
HunyuanSchedulerTaylorCaching
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"Ada"
:
scheduler
=
HunyuanSchedulerAdaCaching
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"Custom"
:
scheduler
=
HunyuanSchedulerCustomCaching
(
self
.
config
)
else
:
else
:
raise
NotImplementedError
(
f
"Unsupported feature_caching type:
{
self
.
config
.
feature_caching
}
"
)
raise
NotImplementedError
(
f
"Unsupported feature_caching type:
{
self
.
config
.
feature_caching
}
"
)
self
.
model
.
set_scheduler
(
scheduler
)
self
.
model
.
set_scheduler
(
scheduler
)
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @
220a631f
...
@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
...
@@ -117,7 +117,7 @@ class WanRunner(DefaultRunner):
scheduler
=
WanScheduler
(
self
.
config
)
scheduler
=
WanScheduler
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"Tea"
:
elif
self
.
config
.
feature_caching
==
"Tea"
:
scheduler
=
WanSchedulerTeaCaching
(
self
.
config
)
scheduler
=
WanSchedulerTeaCaching
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"Taylor"
:
elif
self
.
config
.
feature_caching
==
"Taylor
Seer
"
:
scheduler
=
WanSchedulerTaylorCaching
(
self
.
config
)
scheduler
=
WanSchedulerTaylorCaching
(
self
.
config
)
elif
self
.
config
.
feature_caching
==
"Ada"
:
elif
self
.
config
.
feature_caching
==
"Ada"
:
scheduler
=
WanSchedulerAdaCaching
(
self
.
config
)
scheduler
=
WanSchedulerAdaCaching
(
self
.
config
)
...
...
lightx2v/models/schedulers/hunyuan/feature_caching/scheduler.py
View file @
220a631f
from
.utils
import
cache_init
,
cal_type
from
..scheduler
import
HunyuanScheduler
from
..scheduler
import
HunyuanScheduler
import
torch
import
torch
...
@@ -6,31 +5,32 @@ import torch
...
@@ -6,31 +5,32 @@ import torch
class
HunyuanSchedulerTeaCaching
(
HunyuanScheduler
):
class
HunyuanSchedulerTeaCaching
(
HunyuanScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
cnt
=
0
self
.
num_steps
=
self
.
config
.
infer_steps
self
.
teacache_thresh
=
self
.
config
.
teacache_thresh
self
.
accumulated_rel_l1_distance
=
0
self
.
previous_modulated_input
=
None
self
.
previous_residual
=
None
self
.
coefficients
=
[
7.33226126e02
,
-
4.01131952e02
,
6.75869174e01
,
-
3.14987800e00
,
9.61237896e-02
]
def
clear
(
self
):
def
clear
(
self
):
if
self
.
previous_residual
is
not
None
:
self
.
transformer_infer
.
clear
()
self
.
previous_residual
=
self
.
previous_residual
.
cpu
()
if
self
.
previous_modulated_input
is
not
None
:
self
.
previous_modulated_input
=
self
.
previous_modulated_input
.
cpu
()
self
.
previous_modulated_input
=
None
self
.
previous_residual
=
None
torch
.
cuda
.
empty_cache
()
class
HunyuanSchedulerTaylorCaching
(
HunyuanScheduler
):
class
HunyuanSchedulerTaylorCaching
(
HunyuanScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
cache_dic
,
self
.
current
=
cache_init
(
self
.
infer_steps
)
pattern
=
[
True
,
False
,
False
,
False
]
self
.
caching_records
=
(
pattern
*
((
config
.
infer_steps
+
3
)
//
4
))[:
config
.
infer_steps
]
def
clear
(
self
):
self
.
transformer_infer
.
clear
()
class
HunyuanSchedulerAdaCaching
(
HunyuanScheduler
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
def
step_pre
(
self
,
step_index
):
def
clear
(
self
):
super
().
step_pre
(
step_index
)
self
.
transformer_infer
.
clear
()
self
.
current
[
"step"
]
=
step_index
cal_type
(
self
.
cache_dic
,
self
.
current
)
class
HunyuanSchedulerCustomCaching
(
HunyuanScheduler
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
def
clear
(
self
):
self
.
transformer_infer
.
clear
()
lightx2v/models/schedulers/hunyuan/scheduler.py
View file @
220a631f
...
@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(
...
@@ -237,7 +237,6 @@ def get_1d_rotary_pos_embed_riflex(
class
HunyuanScheduler
(
BaseScheduler
):
class
HunyuanScheduler
(
BaseScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
infer_steps
=
self
.
config
.
infer_steps
self
.
shift
=
7.0
self
.
shift
=
7.0
self
.
timesteps
,
self
.
sigmas
=
set_timesteps_sigmas
(
self
.
infer_steps
,
self
.
shift
,
device
=
torch
.
device
(
"cuda"
))
self
.
timesteps
,
self
.
sigmas
=
set_timesteps_sigmas
(
self
.
infer_steps
,
self
.
shift
,
device
=
torch
.
device
(
"cuda"
))
assert
len
(
self
.
timesteps
)
==
self
.
infer_steps
assert
len
(
self
.
timesteps
)
==
self
.
infer_steps
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment