Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
d2fb49af
Commit
d2fb49af
authored
Aug 15, 2025
by
helloyongyang
Browse files
fix tea cache
parent
cc04b3fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
21 deletions
+17
-21
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
...s/networks/wan/infer/feature_caching/transformer_infer.py
+11
-19
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+6
-2
No files found.
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
View file @
d2fb49af
...
@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self
.
cutoff_steps
=
self
.
config
.
infer_steps
*
2
-
2
self
.
cutoff_steps
=
self
.
config
.
infer_steps
*
2
-
2
# calculate should_calc
# calculate should_calc
@
torch
.
no_grad
()
def
calculate_should_calc
(
self
,
embed
,
embed0
):
def
calculate_should_calc
(
self
,
embed
,
embed0
):
# 1. timestep embedding
# 1. timestep embedding
modulated_inp
=
embed0
if
self
.
use_ret_steps
else
embed
modulated_inp
=
embed0
if
self
.
use_ret_steps
else
embed
...
@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
# 3. return the judgement
# 3. return the judgement
return
should_calc
return
should_calc
def
infer
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
):
def
infer
_main_blocks
(
self
,
weights
,
pre_infer_ou
t
):
if
self
.
infer_conditional
:
if
self
.
infer_conditional
:
index
=
self
.
scheduler
.
step_index
index
=
self
.
scheduler
.
step_index
caching_records
=
self
.
scheduler
.
caching_records
caching_records
=
self
.
scheduler
.
caching_records
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
should_calc
=
self
.
calculate_should_calc
(
embed
,
embed0
)
should_calc
=
self
.
calculate_should_calc
(
pre_infer_out
.
embed
,
pre_infer_out
.
embed0
)
self
.
scheduler
.
caching_records
[
index
]
=
should_calc
self
.
scheduler
.
caching_records
[
index
]
=
should_calc
if
caching_records
[
index
]
or
self
.
must_calc
(
index
):
if
caching_records
[
index
]
or
self
.
must_calc
(
index
):
x
=
self
.
infer_calculating
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
)
x
=
self
.
infer_calculating
(
weights
,
pre_infer_ou
t
)
else
:
else
:
x
=
self
.
infer_using_cache
(
x
)
x
=
self
.
infer_using_cache
(
pre_infer_out
.
x
)
else
:
else
:
index
=
self
.
scheduler
.
step_index
index
=
self
.
scheduler
.
step_index
caching_records_2
=
self
.
scheduler
.
caching_records_2
caching_records_2
=
self
.
scheduler
.
caching_records_2
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
should_calc
=
self
.
calculate_should_calc
(
embed
,
embed0
)
should_calc
=
self
.
calculate_should_calc
(
pre_infer_out
.
embed
,
pre_infer_out
.
embed0
)
self
.
scheduler
.
caching_records_2
[
index
]
=
should_calc
self
.
scheduler
.
caching_records_2
[
index
]
=
should_calc
if
caching_records_2
[
index
]
or
self
.
must_calc
(
index
):
if
caching_records_2
[
index
]
or
self
.
must_calc
(
index
):
x
=
self
.
infer_calculating
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
)
x
=
self
.
infer_calculating
(
weights
,
pre_infer_ou
t
)
else
:
else
:
x
=
self
.
infer_using_cache
(
x
)
x
=
self
.
infer_using_cache
(
pre_infer_out
.
x
)
if
self
.
config
.
enable_cfg
:
if
self
.
config
.
enable_cfg
:
self
.
switch_status
()
self
.
switch_status
()
...
@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return
x
return
x
def
infer_calculating
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
):
def
infer_calculating
(
self
,
weights
,
pre_infer_ou
t
):
ori_x
=
x
.
clone
()
ori_x
=
pre_infer_out
.
x
.
clone
()
x
=
super
().
infer
(
x
=
super
().
infer_main_blocks
(
weights
,
pre_infer_out
)
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
)
if
self
.
infer_conditional
:
if
self
.
infer_conditional
:
self
.
previous_residual_even
=
x
-
ori_x
self
.
previous_residual_even
=
x
-
ori_x
if
self
.
config
[
"cpu_offload"
]:
if
self
.
config
[
"cpu_offload"
]:
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
d2fb49af
...
@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer):
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
pre_infer_out
):
def
infer
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_main_blocks
(
weights
,
pre_infer_out
)
return
self
.
infer_post_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
def
infer_main_blocks
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_func
(
x
=
self
.
infer_func
(
weights
,
weights
,
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
grid_sizes
,
...
@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer):
pre_infer_out
.
context
,
pre_infer_out
.
context
,
pre_infer_out
.
audio_dit_blocks
,
pre_infer_out
.
audio_dit_blocks
,
)
)
return
self
.
_infer_post_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
return
x
def
_
infer_post_blocks
(
self
,
weights
,
x
,
e
):
def
infer_post_blocks
(
self
,
weights
,
x
,
e
):
if
e
.
dim
()
==
2
:
if
e
.
dim
()
==
2
:
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment