Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
d2fb49af
You need to sign in or sign up before continuing.
Commit
d2fb49af
authored
Aug 15, 2025
by
helloyongyang
Browse files
fix tea cache
parent
cc04b3fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
21 deletions
+17
-21
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
...s/networks/wan/infer/feature_caching/transformer_infer.py
+11
-19
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+6
-2
No files found.
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
View file @
d2fb49af
...
@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self
.
cutoff_steps
=
self
.
config
.
infer_steps
*
2
-
2
self
.
cutoff_steps
=
self
.
config
.
infer_steps
*
2
-
2
# calculate should_calc
# calculate should_calc
@
torch
.
no_grad
()
def
calculate_should_calc
(
self
,
embed
,
embed0
):
def
calculate_should_calc
(
self
,
embed
,
embed0
):
# 1. timestep embedding
# 1. timestep embedding
modulated_inp
=
embed0
if
self
.
use_ret_steps
else
embed
modulated_inp
=
embed0
if
self
.
use_ret_steps
else
embed
...
@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
# 3. return the judgement
# 3. return the judgement
return
should_calc
return
should_calc
def
infer
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
):
def
infer
_main_blocks
(
self
,
weights
,
pre_infer_ou
t
):
if
self
.
infer_conditional
:
if
self
.
infer_conditional
:
index
=
self
.
scheduler
.
step_index
index
=
self
.
scheduler
.
step_index
caching_records
=
self
.
scheduler
.
caching_records
caching_records
=
self
.
scheduler
.
caching_records
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
should_calc
=
self
.
calculate_should_calc
(
embed
,
embed0
)
should_calc
=
self
.
calculate_should_calc
(
pre_infer_out
.
embed
,
pre_infer_out
.
embed0
)
self
.
scheduler
.
caching_records
[
index
]
=
should_calc
self
.
scheduler
.
caching_records
[
index
]
=
should_calc
if
caching_records
[
index
]
or
self
.
must_calc
(
index
):
if
caching_records
[
index
]
or
self
.
must_calc
(
index
):
x
=
self
.
infer_calculating
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
)
x
=
self
.
infer_calculating
(
weights
,
pre_infer_ou
t
)
else
:
else
:
x
=
self
.
infer_using_cache
(
x
)
x
=
self
.
infer_using_cache
(
pre_infer_out
.
x
)
else
:
else
:
index
=
self
.
scheduler
.
step_index
index
=
self
.
scheduler
.
step_index
caching_records_2
=
self
.
scheduler
.
caching_records_2
caching_records_2
=
self
.
scheduler
.
caching_records_2
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
if
index
<=
self
.
scheduler
.
infer_steps
-
1
:
should_calc
=
self
.
calculate_should_calc
(
embed
,
embed0
)
should_calc
=
self
.
calculate_should_calc
(
pre_infer_out
.
embed
,
pre_infer_out
.
embed0
)
self
.
scheduler
.
caching_records_2
[
index
]
=
should_calc
self
.
scheduler
.
caching_records_2
[
index
]
=
should_calc
if
caching_records_2
[
index
]
or
self
.
must_calc
(
index
):
if
caching_records_2
[
index
]
or
self
.
must_calc
(
index
):
x
=
self
.
infer_calculating
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
)
x
=
self
.
infer_calculating
(
weights
,
pre_infer_ou
t
)
else
:
else
:
x
=
self
.
infer_using_cache
(
x
)
x
=
self
.
infer_using_cache
(
pre_infer_out
.
x
)
if
self
.
config
.
enable_cfg
:
if
self
.
config
.
enable_cfg
:
self
.
switch_status
()
self
.
switch_status
()
...
@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
...
@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return
x
return
x
def
infer_calculating
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
contex
t
):
def
infer_calculating
(
self
,
weights
,
pre_infer_ou
t
):
ori_x
=
x
.
clone
()
ori_x
=
pre_infer_out
.
x
.
clone
()
x
=
super
().
infer
(
x
=
super
().
infer_main_blocks
(
weights
,
pre_infer_out
)
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
)
if
self
.
infer_conditional
:
if
self
.
infer_conditional
:
self
.
previous_residual_even
=
x
-
ori_x
self
.
previous_residual_even
=
x
-
ori_x
if
self
.
config
[
"cpu_offload"
]:
if
self
.
config
[
"cpu_offload"
]:
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
d2fb49af
...
@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer):
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
pre_infer_out
):
def
infer
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_main_blocks
(
weights
,
pre_infer_out
)
return
self
.
infer_post_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
def
infer_main_blocks
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_func
(
x
=
self
.
infer_func
(
weights
,
weights
,
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
grid_sizes
,
...
@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer):
pre_infer_out
.
context
,
pre_infer_out
.
context
,
pre_infer_out
.
audio_dit_blocks
,
pre_infer_out
.
audio_dit_blocks
,
)
)
return
self
.
_infer_post_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
return
x
def
_
infer_post_blocks
(
self
,
weights
,
x
,
e
):
def
infer_post_blocks
(
self
,
weights
,
x
,
e
):
if
e
.
dim
()
==
2
:
if
e
.
dim
()
==
2
:
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment