xuwx1 / LightX2V / Commits / 40381d5a

Commit 40381d5a authored Jul 28, 2025 by helloyongyang

support cache with changing res

parent 20525490
Showing 5 changed files with 85 additions and 94 deletions (+85 −94)
configs/changing_resolution/wan_t2v_U_teacache.json                        +25 −0
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py    +34 −21
lightx2v/models/runners/wan/wan_runner.py                                   +13 −26
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py             +11 −2
lightx2v/models/schedulers/wan/feature_caching/scheduler.py                 +2 −45
configs/changing_resolution/wan_t2v_U_teacache.json  (new file, 0 → 100755)

{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "changing_resolution": true,
    "resolution_rate": [1.0, 0.75],
    "changing_resolution_steps": [10, 35],
    "feature_caching": "Tea",
    "coefficients": [
        [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02],
        [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
    ],
    "use_ret_steps": false,
    "teacache_thresh": 0.1
}
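Note: this new config pairs TeaCache (feature_caching "Tea") with resolution switching. The changing_resolution_steps entries mark the step indices at which the latent resolution changes (the same indices the transformer is now forced to recompute, see transformer_infer.py below), while resolution_rate scales target_height/target_width for the phases in between. Purely as an illustration, and with the phase mapping and helper name being assumptions rather than code from this repository, the schedule implied by this config could be read as:

import json

# resolution_scale_for_step is a hypothetical helper for illustration only;
# the assumption is that resolution_rate[i] applies to steps before
# changing_resolution_steps[i], and that later steps run at full resolution.
def resolution_scale_for_step(step, changing_resolution_steps, resolution_rate):
    for boundary, rate in zip(changing_resolution_steps, resolution_rate):
        if step < boundary:
            return rate
    return 1.0

with open("configs/changing_resolution/wan_t2v_U_teacache.json") as f:
    cfg = json.load(f)

for step in (0, 10, 34, 35, 49):
    scale = resolution_scale_for_step(step, cfg["changing_resolution_steps"], cfg["resolution_rate"])
    print(f"step {step:2d}: {int(cfg['target_height'] * scale)}x{int(cfg['target_width'] * scale)}")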
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

@@ -5,7 +5,20 @@ import numpy as np
 import gc


-class WanTransformerInferTeaCaching(WanTransformerInfer):
+class WanTransformerInferCaching(WanTransformerInfer):
+    def __init__(self, config):
+        super().__init__(config)
+        self.must_calc_steps = []
+        if self.config.get("changing_resolution", False):
+            self.must_calc_steps = self.config["changing_resolution_steps"]
+
+    def must_calc(self, step_index):
+        if step_index in self.must_calc_steps:
+            return True
+        return False
+
+
+class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
@@ -87,7 +100,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         should_calc = self.calculate_should_calc(embed, embed0)
         self.scheduler.caching_records[index] = should_calc

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -99,7 +112,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         should_calc = self.calculate_should_calc(embed, embed0)
         self.scheduler.caching_records_2[index] = should_calc

-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -169,7 +182,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
+class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
@@ -199,7 +212,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
@@ -208,7 +221,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         index = self.scheduler.step_index
         caching_records_2 = self.scheduler.caching_records_2

-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
@@ -305,7 +318,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         torch.cuda.empty_cache()


-class WanTransformerInferAdaCaching(WanTransformerInfer):
+class WanTransformerInferAdaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -322,7 +335,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

             # 1. calculate the skipped step length
@@ -338,7 +351,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records_2

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

             # 1. calculate the skipped step length
@@ -518,7 +531,7 @@ class AdaArgs:
         self.spatial_dim = 1536


-class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
+class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
@@ -605,7 +618,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         self.scheduler.caching_records[index] = should_calc

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -617,7 +630,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         self.scheduler.caching_records_2[index] = should_calc

-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -683,7 +696,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         torch.cuda.empty_cache()


-class WanTransformerInferFirstBlock(WanTransformerInfer):
+class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -707,7 +720,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records[index] = should_calc

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -719,7 +732,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records_2[index] = should_calc

-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -788,7 +801,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferDualBlock(WanTransformerInfer):
+class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -822,7 +835,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records[index] = should_calc

-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -834,7 +847,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records_2[index] = should_calc

-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -915,7 +928,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferDynamicBlock(WanTransformerInfer):
+class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
         self.residual_diff_threshold = config.residual_diff_threshold
@@ -938,7 +951,7 @@ class WanTransformerInferDynamicBlock(WanTransformerInfer):
         if self.infer_conditional:
             if self.block_in_cache_even[block_idx] is not None:
                 should_calc = self.are_two_tensor_similar(self.block_in_cache_even[block_idx], x)

-            if should_calc:
+            if should_calc or self.must_calc(block_idx):
                 x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             else:
                 x += self.block_residual_cache_even[block_idx]
@@ -953,7 +966,7 @@ class WanTransformerInferDynamicBlock(WanTransformerInfer):
         else:
             if self.block_in_cache_odd[block_idx] is not None:
                 should_calc = self.are_two_tensor_similar(self.block_in_cache_odd[block_idx], x)

-            if should_calc:
+            if should_calc or self.must_calc(block_idx):
                 x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             else:
                 x += self.block_residual_cache_odd[block_idx]
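Note: the substance of this file's change is the new WanTransformerInferCaching base class. Every caching variant (Tea, TaylorSeer, Ada, Custom, FirstBlock, DualBlock, DynamicBlock) now inherits from it, and each cache-hit check gains an "or self.must_calc(index)" escape so that the steps listed in changing_resolution_steps always run a full forward pass instead of reusing features cached at a different resolution. Below is a minimal standalone sketch of the pattern with simplified stand-in classes, not the real LightX2V ones:

# CachingInfer stands in for WanTransformerInferCaching; the real classes take
# a full config object and operate on transformer blocks, not strings.
class CachingInfer:
    def __init__(self, config):
        self.config = config
        self.must_calc_steps = []
        if config.get("changing_resolution", False):
            self.must_calc_steps = config["changing_resolution_steps"]

    def must_calc(self, step_index):
        # Force recomputation on steps where the resolution changes, since the
        # cached features were produced at the previous resolution.
        return step_index in self.must_calc_steps


class TeaLikeCachingInfer(CachingInfer):
    def infer_step(self, step_index, should_calc):
        if should_calc or self.must_calc(step_index):
            return "recompute"
        return "reuse cache"


infer = TeaLikeCachingInfer({"changing_resolution": True, "changing_resolution_steps": [10, 35]})
print(infer.infer_step(10, should_calc=False))  # recompute: resolution changes at step 10
print(infer.infer_step(11, should_calc=False))  # reuse cache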
lightx2v/models/runners/wan/wan_runner.py

@@ -8,16 +8,11 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
-    WanScheduler4ChangingResolution,
+    WanScheduler4ChangingResolutionInterface,
 )
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
-    WanSchedulerTeaCaching,
+    WanSchedulerCaching,
     WanSchedulerTaylorCaching,
-    WanSchedulerAdaCaching,
-    WanSchedulerCustomCaching,
-    WanSchedulerFirstBlock,
-    WanSchedulerDualBlock,
-    WanSchedulerDynamicBlock,
 )
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.utils import *
@@ -159,27 +154,19 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
+        if self.config.feature_caching == "NoCaching":
+            scheduler_class = WanScheduler
+        elif self.config.feature_caching == "TaylorSeer":
+            scheduler_class = WanSchedulerTaylorCaching
+        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
+            scheduler_class = WanSchedulerCaching
+        else:
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+
         if self.config.get("changing_resolution", False):
-            scheduler = WanScheduler4ChangingResolution(self.config)
+            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
         else:
-            if self.config.feature_caching == "NoCaching":
-                scheduler = WanScheduler(self.config)
-            elif self.config.feature_caching == "Tea":
-                scheduler = WanSchedulerTeaCaching(self.config)
-            elif self.config.feature_caching == "TaylorSeer":
-                scheduler = WanSchedulerTaylorCaching(self.config)
-            elif self.config.feature_caching == "Ada":
-                scheduler = WanSchedulerAdaCaching(self.config)
-            elif self.config.feature_caching == "Custom":
-                scheduler = WanSchedulerCustomCaching(self.config)
-            elif self.config.feature_caching == "FirstBlock":
-                scheduler = WanSchedulerFirstBlock(self.config)
-            elif self.config.feature_caching == "DualBlock":
-                scheduler = WanSchedulerDualBlock(self.config)
-            elif self.config.feature_caching == "DynamicBlock":
-                scheduler = WanSchedulerDynamicBlock(self.config)
-            else:
-                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            scheduler = scheduler_class(self.config)

         self.model.set_scheduler(scheduler)

     def run_text_encoder(self, text, img):
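Note: init_scheduler now selects a scheduler class first and only afterwards decides how to instantiate it; with changing_resolution enabled, the class is handed to WanScheduler4ChangingResolutionInterface, which combines it with the resolution-changing scheduler (see the next file). The same selection logic rephrased as a lookup table, purely as a sketch and not the code in wan_runner.py:

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import WanScheduler4ChangingResolutionInterface
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerCaching, WanSchedulerTaylorCaching

# All caching modes except TaylorSeer now share WanSchedulerCaching.
CACHING_TO_SCHEDULER = {
    "NoCaching": WanScheduler,
    "TaylorSeer": WanSchedulerTaylorCaching,
    "Tea": WanSchedulerCaching,
    "Ada": WanSchedulerCaching,
    "Custom": WanSchedulerCaching,
    "FirstBlock": WanSchedulerCaching,
    "DualBlock": WanSchedulerCaching,
    "DynamicBlock": WanSchedulerCaching,
}

def build_scheduler(config):
    # config is assumed to support both attribute and dict-style access,
    # as it does in the runner code above.
    if config.feature_caching not in CACHING_TO_SCHEDULER:
        raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
    scheduler_class = CACHING_TO_SCHEDULER[config.feature_caching]
    if config.get("changing_resolution", False):
        return WanScheduler4ChangingResolutionInterface(scheduler_class, config)
    return scheduler_class(config)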
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py

@@ -2,9 +2,18 @@ import torch
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler


-class WanScheduler4ChangingResolution(WanScheduler):
+class WanScheduler4ChangingResolutionInterface:
+    def __new__(cls, father_scheduler, config):
+        class NewClass(WanScheduler4ChangingResolution, father_scheduler):
+            def __init__(self, config):
+                father_scheduler.__init__(self, config)
+                WanScheduler4ChangingResolution.__init__(self, config)
+
+        return NewClass(config)
+
+
+class WanScheduler4ChangingResolution:
     def __init__(self, config):
-        super().__init__(config)
         if "resolution_rate" not in config:
             config["resolution_rate"] = [0.75]
         if "changing_resolution_steps" not in config:
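Note: WanScheduler4ChangingResolutionInterface is a small factory. Calling it with a scheduler class returns an instance of a class built on the fly that inherits from both WanScheduler4ChangingResolution and the passed-in scheduler, running both initializers explicitly (the resolution class no longer calls super().__init__ itself). The same dynamic-mixin pattern in isolation, with toy classes in place of the real schedulers:

# Toy demonstration of the __new__-based factory pattern; BaseScheduler and
# ResolutionMixin are placeholders, not LightX2V classes.
class BaseScheduler:
    def __init__(self, config):
        self.config = config


class ResolutionMixin:
    def __init__(self, config):
        self.resolution_rate = config.get("resolution_rate", [0.75])


class MixinInterface:
    def __new__(cls, father_scheduler, config):
        # Build a combined class from the mixin and whichever scheduler class
        # the caller selected, then return an instance of it.
        class Combined(ResolutionMixin, father_scheduler):
            def __init__(self, config):
                father_scheduler.__init__(self, config)
                ResolutionMixin.__init__(self, config)

        return Combined(config)


obj = MixinInterface(BaseScheduler, {"resolution_rate": [1.0, 0.75]})
print(type(obj).__name__, obj.resolution_rate)  # Combined [1.0, 0.75]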
lightx2v/models/schedulers/wan/feature_caching/scheduler.py

 from lightx2v.models.schedulers.wan.scheduler import WanScheduler


-class WanSchedulerTeaCaching(WanScheduler):
+class WanSchedulerCaching(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
@@ -9,53 +9,10 @@ class WanSchedulerTeaCaching(WanScheduler):
         self.transformer_infer.clear()


-class WanSchedulerTaylorCaching(WanScheduler):
+class WanSchedulerTaylorCaching(WanSchedulerCaching):
     def __init__(self, config):
         super().__init__(config)
         pattern = [True, False, False, False]
         self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
         self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerAdaCaching(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerCustomCaching(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerFirstBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerDualBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerDynamicBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
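Note: with the per-variant scheduler classes collapsed into WanSchedulerCaching (whose clear() delegates to transformer_infer.clear()), only WanSchedulerTaylorCaching keeps extra state: a repeating [True, False, False, False] pattern trimmed to infer_steps, i.e. a forced full computation every fourth step. A quick standalone check of that expression:

# Reproduces the caching_records construction from WanSchedulerTaylorCaching.
infer_steps = 10
pattern = [True, False, False, False]
caching_records = (pattern * ((infer_steps + 3) // 4))[:infer_steps]
print(caching_records)
# [True, False, False, False, True, False, False, False, True, False]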