xuwx1 / LightX2V / Commits

Commit ad051778
Authored Jul 01, 2025 by gushiqiao. Committed by GitHub on Jul 01, 2025.

    Fix (#80)

    * Fix
    * Fix
    * Fix

    Co-authored-by: gushiqiao <gushiqiao@sensetime.com>

Parent: fb69083e

Showing 8 changed files with 303 additions and 255 deletions (+303 -255).
Changed files:

    lightx2v/models/networks/wan/distill_model.py                             +1   -1
    lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py   +89  -162
    lightx2v/models/networks/wan/infer/post_infer.py                          +11  -5
    lightx2v/models/networks/wan/infer/pre_infer.py                           +7   -2
    lightx2v/models/networks/wan/infer/transformer_infer.py                   +112 -49
    lightx2v/models/networks/wan/infer/utils.py                               +33  -0
    lightx2v/models/networks/wan/model.py                                     +13  -25
    lightx2v/models/networks/wan/weights/transformer_weights.py               +37  -11
lightx2v/models/networks/wan/distill_model.py (+1 -1)

@@ -24,7 +24,7 @@ class WanDistillModel(WanModel):
 ...
         ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if not os.path.exists(ckpt_path):
             # The checkpoint file does not exist; fall back to the parent class's _load_ckpt
-            return super()._load_ckpt()
+            return super()._load_ckpt(use_bf16, skip_bf16)
         weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
 ...
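As context for the dict comprehension above: the loader casts weights to bfloat16 selectively before pinning and copying them to the device. A minimal standalone sketch of that pattern, detached from the class (the function name and the skip_bf16 default here are illustrative, not the repository's API; pin_memory assumes a CUDA-enabled build):

    import torch

    def cast_and_move(weight_dict, device, use_bf16=True, skip_bf16=("norm",)):
        # Mirrors the comprehension above: cast to bfloat16 when use_bf16 is set, or when
        # the key matches none of the skip substrings; then pin host memory and copy to device.
        out = {}
        for key, tensor in weight_dict.items():
            if use_bf16 or all(s not in key for s in skip_bf16):
                tensor = tensor.to(torch.bfloat16)
            out[key] = tensor.pin_memory().to(device)
        return out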
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py (+89 -162)

@@ -2,13 +2,14 @@ from ..transformer_infer import WanTransformerInfer
 from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
 import torch
 import numpy as np
+import gc


 # 1. TeaCaching
 class WanTransformerInferTeaCaching(WanTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
         self.teacache_thresh = config.teacache_thresh
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None

@@ -16,71 +17,18 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         self.accumulated_rel_l1_distance_odd = 0
         self.previous_e0_odd = None
         self.previous_residual_odd = None
         self.use_ret_steps = config.use_ret_steps
-        self.set_attributes_by_task_and_model()
-
-    def set_attributes_by_task_and_model(self):
-        if self.config.task == "i2v":
-            if self.use_ret_steps:
-                if self.config.target_width == 480 or self.config.target_height == 480:
-                    self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01]
-                if self.config.target_width == 720 or self.config.target_height == 720:
-                    self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02]
-                self.ret_steps = 5 * 2
-                self.cutoff_steps = self.config.infer_steps * 2
-            else:
-                if self.config.target_width == 480 or self.config.target_height == 480:
-                    self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
-                if self.config.target_width == 720 or self.config.target_height == 720:
-                    self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
-                self.ret_steps = 1 * 2
-                self.cutoff_steps = self.config.infer_steps * 2 - 2
-        elif self.config.task == "t2v":
-            if self.use_ret_steps:
-                if "1.3B" in self.config.model_path:
-                    self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
-                if "14B" in self.config.model_path:
-                    self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
-                self.ret_steps = 5 * 2
-                self.cutoff_steps = self.config.infer_steps * 2
-            else:
-                if "1.3B" in self.config.model_path:
-                    self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
-                if "14B" in self.config.model_path:
-                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
-                self.ret_steps = 1 * 2
-                self.cutoff_steps = self.config.infer_steps * 2 - 2
+        if self.use_ret_steps:
+            self.coefficients = self.config.coefficients[0]
+            self.ret_steps = 5 * 2  # only in Wan2.1 TeaCaching
+            self.cutoff_steps = self.config.infer_steps * 2
+        else:
+            self.coefficients = self.config.coefficients[1]
+            self.ret_steps = 1 * 2
+            self.cutoff_steps = self.config.infer_steps * 2 - 2

     # calculate should_calc
-    def calculate_should_calc(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def calculate_should_calc(self, embed, embed0):
         # 1. timestep embedding
         modulated_inp = embed0 if self.use_ret_steps else embed
 ...

@@ -92,13 +40,15 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
             self.accumulated_rel_l1_distance_even = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even.cuda()).abs().mean() / self.previous_e0_even.cuda().abs().mean()).cpu().item())
             if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                 should_calc = False
             else:
                 should_calc = True
                 self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = modulated_inp.clone()
+        if self.config["cpu_offload"]:
+            self.previous_e0_even = self.previous_e0_even.cpu()

@@ -106,7 +56,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
The odd branch gets the same treatment:
-            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.cuda()).abs().mean() / self.previous_e0_odd.cuda().abs().mean()).cpu().item())

@@ -114,6 +64,19 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         self.previous_e0_odd = modulated_inp.clone()
+        if self.config["cpu_offload"]:
+            self.previous_e0_odd = self.previous_e0_odd.cpu()
+
+        if self.config["cpu_offload"]:
+            modulated_inp = modulated_inp.cpu()
+            del modulated_inp
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        if self.clean_cuda_cache:
+            del embed, embed0
+            torch.cuda.empty_cache()

         # 3. return the judgement
         return should_calc

@@ -122,54 +85,71 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records
         if index <= self.scheduler.infer_steps - 1:
-            should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+            should_calc = self.calculate_should_calc(embed, embed0)
             self.scheduler.caching_records[index] = should_calc

         if caching_records[index]:
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
-            x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+            x = self.infer_using_cache(x)
 ...
The caching_records_2 branch in the else arm is updated identically.
 ...
         if self.config.enable_cfg:
             self.switch_status()
         self.cnt += 1
+
+        if self.clean_cuda_cache:
+            del grid_sizes, embed, embed0, seq_lens, freqs, context
+            torch.cuda.empty_cache()
         return x

     def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
-        for block_idx in range(self.blocks_num):
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
-            y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
-            attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
-            y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
-            x = self.infer_phase_5(x, y_out, c_gate_msa)
+        x = super().infer(
+            weights,
+            grid_sizes,
+            embed,
+            x,
+            embed0,
+            seq_lens,
+            freqs,
+            context,
+        )
         if self.infer_conditional:
             self.previous_residual_even = x - ori_x
+            if self.config["cpu_offload"]:
+                self.previous_residual_even = self.previous_residual_even.cpu()
         else:
             self.previous_residual_odd = x - ori_x
+            if self.config["cpu_offload"]:
+                self.previous_residual_odd = self.previous_residual_odd.cpu()
+
+        if self.config["cpu_offload"]:
+            ori_x = ori_x.to("cpu")
+            del ori_x
+            torch.cuda.empty_cache()
+            gc.collect()
         return x

-    def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_using_cache(self, x):
         if self.infer_conditional:
-            x += self.previous_residual_even
+            x.add_(self.previous_residual_even.cuda())
         else:
-            x += self.previous_residual_odd
+            x.add_(self.previous_residual_odd.cuda())
         return x

@@ -241,27 +221,27 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
In infer_calculating, the per-block phase calls are renamed and re-pointed at the new compute_phases layout (modulation is now compute_phases[0]); the derivative_approximation calls on "self_attn_out", "cross_attn_out" and "ffn_out" are unchanged:
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(weights.blocks[block_idx].compute_phases[0], embed0)
-            y_out = self.infer_phase_2(weights.blocks[block_idx].compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+            y_out = self.infer_self_attn(weights.blocks[block_idx].compute_phases[1], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
-            attn_out = self.infer_phase_3(weights.blocks[block_idx].compute_phases[1], x, context, y_out, gate_msa)
+            attn_out = self.infer_cross_attn(weights.blocks[block_idx].compute_phases[2], x, context, y_out, gate_msa)
-            y_out = self.infer_phase_4(weights.blocks[block_idx].compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
+            y_out = self.infer_ffn(weights.blocks[block_idx].compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
-            x = self.infer_phase_5(x, y_out, c_gate_msa)
+            x = self.post_process(x, y_out, c_gate_msa)

@@ -272,7 +252,7 @@ In TaylorCaching.infer_block:
         # 1. shift, scale, gate
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+        _, _, gate_msa, _, _, c_gate_msa = self.infer_modulation(weights, embed0)

@@ -369,7 +349,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         else:
-            x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+            x = self.infer_using_cache(x)

@@ -380,18 +360,18 @@ AdaCaching.infer_calculating gets the same infer_modulation / infer_self_attn / infer_cross_attn / infer_ffn / post_process renames, and
@@ -399,7 +379,7 @@
-    def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_using_cache(self, x):

@@ -542,7 +522,7 @@ and @@ -552,68 +532,15 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
__init__ drops set_attributes_by_task_and_model() and its hard-coded polynomial tables and reads self.config.coefficients[0] / [1], exactly as in WanTransformerInferTeaCaching above.

@@ -682,7 +609,7 @@ and @@ -694,7 +621,7 @@
-            x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+            x = self.infer_using_cache(x)

@@ -707,12 +634,12 @@ CustomCaching.infer_calculating: same phase renames as above.

@@ -722,7 +649,7 @@
-    def infer_using_cache(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_using_cache(self, x):
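For orientation, a minimal standalone sketch of the TeaCache-style skip rule that calculate_should_calc implements: the relative L1 change of the modulated timestep embedding is rescaled with a fitted polynomial and accumulated, and the transformer blocks are only recomputed when the accumulator crosses the threshold. The wiring follows the diff above; the function name, the example tensors, and the 0.26 threshold are illustrative only:

    import numpy as np
    import torch

    def should_recompute(modulated_inp, previous_inp, accumulated, coefficients, thresh):
        # Rescale the relative L1 distance with the fitted polynomial and accumulate it.
        rescale = np.poly1d(coefficients)
        rel_l1 = ((modulated_inp - previous_inp).abs().mean() / previous_inp.abs().mean()).item()
        accumulated += rescale(rel_l1)
        if accumulated < thresh:
            return False, accumulated   # reuse the cached residual
        return True, 0.0                # recompute and reset the accumulator

    # Example with dummy tensors and the i2v-480P ret-steps coefficients shown in the removed code:
    prev = torch.randn(1, 6, 1536)
    cur = prev + 0.01 * torch.randn_like(prev)
    coeffs = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01]
    calc, acc = should_recompute(cur, prev, 0.0, coeffs, thresh=0.26)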
lightx2v/models/networks/wan/infer/post_infer.py (+11 -5)

@@ -8,6 +8,7 @@ class WanPostInfer:
     def __init__(self, config):
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
 ...

@@ -21,16 +22,21 @@ class WanPostInfer:
         e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
         e = [ei.squeeze(1) for ei in e]
-        norm_out = weights.norm.apply(x)
+        x = weights.norm.apply(x)
         if GET_DTYPE() != "BF16":
-            norm_out = norm_out.float()
-        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
+            x = x.float()
+        x.mul_(1 + e[1].squeeze(0)).add_(e[0].squeeze(0))
         if GET_DTYPE() != "BF16":
-            out = out.to(torch.bfloat16)
-        x = weights.head.apply(out)
+            x = x.to(torch.bfloat16)
+        x = weights.head.apply(x)
         x = self.unpatchify(x, grid_sizes)
+
+        if self.clean_cuda_cache:
+            del e, grid_sizes
+            torch.cuda.empty_cache()
         return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
 ...
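The switch from out = norm_out * (1 + scale) + shift to x.mul_(...).add_(...) replaces out-of-place arithmetic with in-place updates, which lowers peak activation memory at the cost of overwriting the input. A tiny standalone comparison (the shapes are arbitrary, not the model's):

    import torch

    x = torch.randn(2, 1024, 1536)
    scale = torch.randn(1, 1, 1536)
    shift = torch.randn(1, 1, 1536)

    # Out-of-place: allocates fresh tensors for the product and the sum.
    out = x * (1 + scale) + shift

    # In-place: reuses x's storage, so no extra full-size temporaries survive the line.
    x.mul_(1 + scale).add_(shift)
    assert torch.allclose(out, x)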
lightx2v/models/networks/wan/infer/pre_infer.py (+7 -2)

@@ -7,7 +7,7 @@ class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         d = config["dim"] // config["num_heads"]
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.freqs = torch.cat(
             [
 ...

@@ -87,6 +87,9 @@ class WanPreInfer:
         out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)
+        if self.clean_cuda_cache:
+            del out, stacked
+            torch.cuda.empty_cache()
         if self.task == "i2v":
             context_clip = weights.proj_0.apply(clip_fea)
 ...

@@ -95,7 +98,9 @@ class WanPreInfer:
             context_clip = weights.proj_3.apply(context_clip)
             context_clip = weights.proj_4.apply(context_clip)
             context = torch.concat([context_clip, context], dim=0)
+            if self.clean_cuda_cache:
+                del context_clip, clip_fea
+                torch.cuda.empty_cache()

         return (
             embed,
             grid_sizes,
 ...
lightx2v/models/networks/wan/infer/transformer_infer.py (+112 -49)

 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
+from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 ...

@@ -14,11 +14,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config["num_layers"]
-        self.phases_num = 3
+        self.phases_num = 4
         self.num_heads = config["num_heads"]
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
 ...

@@ -92,10 +94,6 @@ class WanTransformerInfer(BaseTransformerInfer):
Modulation is no longer handled outside the phase loop in the offload path:
     def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(weights.blocks_num):
-            weights.blocks[block_idx].modulation.to_cuda()
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     phase = weights.blocks[block_idx].compute_phases[phase_idx]
 ...

@@ -105,12 +103,23 @@ class WanTransformerInfer(BaseTransformerInfer):
                     cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                     if cur_phase_idx == 0:
-                        y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, embed0)
                     elif cur_phase_idx == 1:
-                        attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
+                        y_out = self.infer_self_attn(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
                     elif cur_phase_idx == 2:
-                        y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.infer_phase_5(x, y, c_gate_msa)
+                        attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
+                    elif cur_phase_idx == 3:
+                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                        x = self.post_process(x, y, c_gate_msa)
                     is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                     if not is_last_phase:
 ...

@@ -120,8 +129,6 @@ class WanTransformerInfer(BaseTransformerInfer):
             self.weights_stream_mgr.swap_phases()
-            weights.blocks[block_idx].modulation.to_cpu()
         torch.cuda.empty_cache()
         return x

@@ -130,11 +137,6 @@ and @@ -152,12 +154,25 @@ apply the same changes to the lazy-offload path: the per-block modulation.to_cuda() / infer_phase_1 preamble is dropped, and the phase dispatch is remapped to infer_modulation / infer_self_attn / infer_cross_attn / infer_ffn + post_process.

@@ -166,10 +181,16 @@ class WanTransformerInfer(BaseTransformerInfer):
                 self.weights_stream_mgr.swap_phases()
-            weights.blocks[block_idx].modulation.to_cpu()
             self.weights_stream_mgr._async_prefetch_block(weights)
-            torch.cuda.empty_cache()
+            if self.clean_cuda_cache:
+                del attn_out, y_out, y
+                torch.cuda.empty_cache()
+
+        if self.clean_cuda_cache:
+            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
+            del grid_sizes, embed, embed0, seq_lens, freqs, context
+            torch.cuda.empty_cache()
         return x
 ...

@@ -188,36 +209,51 @@ class WanTransformerInfer(BaseTransformerInfer):
infer_block and the per-phase methods are renamed; modulation becomes compute_phases[0]:
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
-        y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
-        attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
-        y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
-        x = self.infer_phase_5(x, y, c_gate_msa)
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
+            weights.compute_phases[0],
+            embed0,
+        )
+        y_out = self.infer_self_attn(
+            weights.compute_phases[1],
+            grid_sizes,
+            x,
+            seq_lens,
+            freqs,
+            shift_msa,
+            scale_msa,
+        )
+        attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
+        y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
+        x = self.post_process(x, y, c_gate_msa)
         return x

-    def infer_phase_1(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_modulation(self, weights, embed0):
         if embed0.dim() == 3:
             modulation = weights.modulation.tensor.unsqueeze(2)
             embed0 = (modulation + embed0).chunk(6, dim=1)
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
         elif embed0.dim() == 2:
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+        if self.clean_cuda_cache:
+            del embed0
+            torch.cuda.empty_cache()
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_phase_2(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
-            norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
-            norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
+            norm1_weight = (1 + scale_msa.squeeze(0)) * weights.smooth_norm1_weight.tensor
+            norm1_bias = shift_msa.squeeze(0) * weights.smooth_norm1_bias.tensor
         else:
-            norm1_weight = 1 + scale_msa
-            norm1_bias = shift_msa
+            norm1_weight = 1 + scale_msa.squeeze(0)
+            norm1_bias = shift_msa.squeeze(0)
         norm1_out = weights.norm1.apply(x)
         if GET_DTYPE() != "BF16":
             norm1_out = norm1_out.float()
-        norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)
+        norm1_out.mul_(norm1_weight).add_(norm1_bias)
         if GET_DTYPE() != "BF16":
             norm1_out = norm1_out.to(torch.bfloat16)
 ...
         else:
             freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
-        q = apply_rotary_emb(q, freqs_i)
-        k = apply_rotary_emb(k, freqs_i)
+        q = self.apply_rotary_emb_func(q, freqs_i)
+        k = self.apply_rotary_emb_func(k, freqs_i)
         cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
 ...
         y = weights.self_attn_o.apply(attn_out)
+        if self.clean_cuda_cache:
+            del q, k, v, attn_out, freqs_i, norm1_out, norm1_weight, norm1_bias
+            torch.cuda.empty_cache()
         return y

-    def infer_phase_3(self, weights, x, context, y_out, gate_msa):
+    def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y_out.float() * gate_msa.squeeze(0)
         else:
 ...
                 max_seqlen_kv=k_img.size(0),
                 model_cls=self.config["model_cls"],
             )
-            attn_out = attn_out + img_attn_out
+            attn_out.add_(img_attn_out)
+            if self.clean_cuda_cache:
+                del k_img, v_img, img_attn_out
+                torch.cuda.empty_cache()
         attn_out = weights.cross_attn_o.apply(attn_out)
+        if self.clean_cuda_cache:
+            del q, k, v, norm3_out, context, context_img
+            torch.cuda.empty_cache()
         return attn_out

-    def infer_phase_4(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+    def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
         x.add_(attn_out)
+        if self.clean_cuda_cache:
+            del attn_out
+            torch.cuda.empty_cache()
         if hasattr(weights, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
 ...
         else:
             norm2_weight = 1 + c_scale_msa.squeeze(0)
             norm2_bias = c_shift_msa.squeeze(0)
-        norm2_out = weights.norm2.apply(x)
+        x = weights.norm2.apply(x)
         if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.float()
-        norm2_out = norm2_out * norm2_weight + norm2_bias
+            x = x.float()
+        x.mul_(norm2_weight).add_(norm2_bias)
         if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.to(torch.bfloat16)
-        y = weights.ffn_0.apply(norm2_out)
-        y = torch.nn.functional.gelu(y, approximate="tanh")
-        y = weights.ffn_2.apply(y)
-        return y
+            x = x.to(torch.bfloat16)
+        x = weights.ffn_0.apply(x)
+        x = torch.nn.functional.gelu(x, approximate="tanh")
+        if self.clean_cuda_cache:
+            torch.cuda.empty_cache()
+        x = weights.ffn_2.apply(x)
+        return x

-    def infer_phase_5(self, x, y, c_gate_msa):
+    def post_process(self, x, y, c_gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y.float() * c_gate_msa.squeeze(0)
         else:
             x.add_(y * c_gate_msa.squeeze(0))
+        if self.clean_cuda_cache:
+            del y, c_gate_msa
+            torch.cuda.empty_cache()
         return x
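A conceptual sketch of why the block is now split into four weight phases (modulation, self-attention, cross-attention, FFN): with per-phase weights, an offload manager can prefetch the next phase's weights while the current phase computes. Everything below is hypothetical scaffolding for illustration, not the repository's WeightAsyncStreamManager API:

    import torch

    PHASES = ("modulation", "self_attn", "cross_attn", "ffn")

    def run_block_with_phase_offload(phase_fns, x, prefetch_next, compute_stream):
        # phase_fns: mapping from phase name to a callable that already holds that phase's
        # GPU-resident weights; prefetch_next(name) starts the host-to-device copy of the
        # named phase's weights on a side stream.
        for i, name in enumerate(PHASES):
            with torch.cuda.stream(compute_stream):
                x = phase_fns[name](x)           # compute the current phase
            if i + 1 < len(PHASES):
                prefetch_next(PHASES[i + 1])     # overlap the next phase's weight transfer
        return x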
lightx2v/models/networks/wan/infer/utils.py (+33 -0)

@@ -75,6 +75,39 @@ def apply_rotary_emb(x, freqs_i):
     return x_i.to(torch.bfloat16)


+def apply_rotary_emb_chunk(x, freqs_i, chunk_size=100, remaining_chunk_size=100):
+    n = x.size(1)
+    seq_len = freqs_i.size(0)
+    output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        x_chunk = x[start:end]
+        freqs_chunk = freqs_i[start:end]
+        x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
+        x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
+        output_chunks.append(x_chunk_embedded)
+        del x_chunk_complex, x_chunk_embedded
+        torch.cuda.empty_cache()
+
+    result = []
+    for chunk in output_chunks:
+        result.append(chunk)
+    del output_chunks
+    torch.cuda.empty_cache()
+
+    for start in range(seq_len, x.size(0), remaining_chunk_size):
+        end = min(start + remaining_chunk_size, x.size(0))
+        result.append(x[start:end])
+
+    x_i = torch.cat(result, dim=0)
+    del result
+    torch.cuda.empty_cache()
+    return x_i.to(torch.bfloat16)
+
+
 def rope_params(max_seq_len, dim, theta=10000):
     assert dim % 2 == 0
     freqs = torch.outer(
 ...
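A hedged usage sketch for the new chunked rotary embedding, which is selected via the rotary_chunk config flag added in transformer_infer.py. The import path comes from this commit; the tensor shapes are an assumption inferred from the reshape / view_as_complex pattern above (x: [seq_len, num_heads, head_dim]; freqs_i: complex, broadcastable to [seq_len, 1, head_dim // 2]) rather than something the diff states:

    import torch
    from lightx2v.models.networks.wan.infer.utils import apply_rotary_emb_chunk

    seq_len, num_heads, head_dim = 320, 12, 128
    x = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    freqs_i = torch.polar(torch.ones_like(angles), angles).unsqueeze(1)  # [seq_len, 1, head_dim // 2]

    # Processes 100 tokens at a time instead of materializing the full complex tensor at once.
    out = apply_rotary_emb_chunk(x, freqs_i, chunk_size=100)
    assert out.shape == x.shape and out.dtype == torch.bfloat16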
lightx2v/models/networks/wan/model.py (+13 -25)

@@ -34,6 +34,7 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
         self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
 ...

@@ -133,22 +134,7 @@ class WanModel:
                 else:
                     pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
-
-        safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
-        safetensors_files = glob.glob(safetensors_pattern)
-        if not safetensors_files:
-            raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")
-
-        for file_path in safetensors_files:
-            with safe_open(file_path, framework="pt") as f:
-                for k in f.keys():
-                    if "modulation" in k:
-                        if f.get_tensor(k).dtype == torch.float:
-                            if use_bf16 or all(s not in k for s in skip_bf16):
-                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
-                            else:
-                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
-        return pre_post_weight_dict, transformer_weight_dict
+        return pre_post_weight_dict

     def _init_weights(self, weight_dict=None):
         use_bf16 = GET_DTYPE() == "BF16"
 ...

@@ -161,10 +147,7 @@ class WanModel:
             if not self.config.get("lazy_load", False):
                 self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
             else:
-                (
-                    self.original_weight_dict,
-                    self.transformer_weight_dict,
-                ) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
+                self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
         else:
             self.original_weight_dict = weight_dict

         # init weights
 ...

@@ -174,10 +157,7 @@ class WanModel:
         # load weights
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
-        if hasattr(self, "transformer_weight_dict"):
-            self.transformer_weights.load(self.transformer_weight_dict)
-        else:
-            self.transformer_weights.load(self.original_weight_dict)
+        self.transformer_weights.load(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
 ...

@@ -212,13 +192,21 @@ class WanModel:
         self.scheduler.noise_pred = noise_pred_cond
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
+            torch.cuda.empty_cache()

         if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

             if self.config["cpu_offload"]:
                 self.pre_weight.to_cpu()
                 self.post_weight.to_cpu()
+
+            if self.clean_cuda_cache:
+                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+                torch.cuda.empty_cache()
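The guidance line above (now reading the cached conditional prediction back from self.scheduler.noise_pred, since noise_pred_cond may already have been freed) is standard classifier-free guidance. As a detached reference, an illustrative helper rather than the repository's API:

    import torch

    def cfg_combine(noise_pred_cond, noise_pred_uncond, guide_scale):
        # Classifier-free guidance: start from the unconditional prediction and push it
        # toward the conditional one by guide_scale times their difference.
        return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

    cond = torch.randn(4, 16)
    uncond = torch.randn(4, 16)
    assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)  # scale 1 recovers the conditional branch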
lightx2v/models/networks/wan/weights/transformer_weights.py (+37 -11)

@@ -34,13 +34,8 @@ class WanTransformerAttentionBlock(WeightModule):
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
         self.sparge = config.get("sparge", False)
-        self.register_parameter(
-            "modulation",
-            TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
-        )
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
             lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
             self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
 ...

@@ -49,6 +44,14 @@ class WanTransformerAttentionBlock(WeightModule):
         self.compute_phases = WeightModuleList(
             [
+                WanModulation(
+                    block_index,
+                    task,
+                    mm_type,
+                    config,
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
                 WanSelfAttention(
                     block_index,
                     task,
 ...

@@ -79,6 +82,29 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("compute_phases", self.compute_phases)


+class WanModulation(WeightModule):
+    def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
+        super().__init__()
+        self.block_index = block_index
+        self.mm_type = mm_type
+        self.task = task
+        self.config = config
+        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.sparge = config.get("sparge", False)
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
+
+        self.add_module(
+            "modulation",
+            TENSOR_REGISTER["Default"](
+                f"blocks.{self.block_index}.modulation",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+
+
 class WanSelfAttention(WeightModule):
     def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
 ...

@@ -92,7 +118,7 @@ class WanSelfAttention(WeightModule):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
-        self.register_parameter(
+        self.add_module(
             "norm1",
             LN_WEIGHT_REGISTER["Default"](),
         )
 ...

@@ -160,7 +186,7 @@ and @@ -168,7 +194,7 @@ class WanSelfAttention(WeightModule):
In the smoothquant/awq branch, the affine_norm1 weight and bias registrations switch from register_parameter to add_module:
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
         if self.quant_method in ["smoothquant", "awq"]:
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm1_weight",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm1.weight",
 ...
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm1_bias",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm1.bias",
 ...

@@ -292,7 +318,7 @@, @@ -317,7 +343,7 @@ and @@ -325,7 +351,7 @@ class WanFFN(WeightModule):
The same register_parameter to add_module switch is applied to "norm2", "smooth_norm2_weight" (blocks.{i}.affine_norm3.weight) and "smooth_norm2_bias" (blocks.{i}.affine_norm3.bias).
 ...