Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
fba9754a
Commit
fba9754a
authored
Aug 20, 2025
by
gushiqiao
Committed by
GitHub
Aug 20, 2025
Browse files
[Reconstruct] recon infer class (#228)
parent
d8a2731b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
271 additions
and
279 deletions
+271
-279
lightx2v/models/networks/wan/audio_model.py
lightx2v/models/networks/wan/audio_model.py
+4
-2
lightx2v/models/networks/wan/infer/audio/post_infer.py
lightx2v/models/networks/wan/infer/audio/post_infer.py
+0
-0
lightx2v/models/networks/wan/infer/audio/pre_infer.py
lightx2v/models/networks/wan/infer/audio/pre_infer.py
+0
-0
lightx2v/models/networks/wan/infer/audio/transformer_infer.py
...tx2v/models/networks/wan/infer/audio/transformer_infer.py
+25
-0
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py
...2v/models/networks/wan/infer/causvid/transformer_infer.py
+2
-2
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
...s/networks/wan/infer/feature_caching/transformer_infer.py
+2
-3
lightx2v/models/networks/wan/infer/offload/__init__.py
lightx2v/models/networks/wan/infer/offload/__init__.py
+0
-0
lightx2v/models/networks/wan/infer/offload/transformer_infer.py
...2v/models/networks/wan/infer/offload/transformer_infer.py
+201
-0
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+17
-268
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+20
-4
No files found.
lightx2v/models/networks/wan/audio_model.py
100644 → 100755
View file @
fba9754a
import
glob
import
glob
import
os
import
os
from
lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer
import
WanAudioPostInfer
from
lightx2v.models.networks.wan.infer.audio.post_infer
import
WanAudioPostInfer
from
lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer
import
WanAudioPreInfer
from
lightx2v.models.networks.wan.infer.audio.pre_infer
import
WanAudioPreInfer
from
lightx2v.models.networks.wan.infer.audio.transformer_infer
import
WanAudioTransformerInfer
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.weights.post_weights
import
WanPostWeights
from
lightx2v.models.networks.wan.weights.post_weights
import
WanPostWeights
from
lightx2v.models.networks.wan.weights.pre_weights
import
WanPreWeights
from
lightx2v.models.networks.wan.weights.pre_weights
import
WanPreWeights
...
@@ -23,6 +24,7 @@ class WanAudioModel(WanModel):
...
@@ -23,6 +24,7 @@ class WanAudioModel(WanModel):
super
().
_init_infer_class
()
super
().
_init_infer_class
()
self
.
pre_infer_class
=
WanAudioPreInfer
self
.
pre_infer_class
=
WanAudioPreInfer
self
.
post_infer_class
=
WanAudioPostInfer
self
.
post_infer_class
=
WanAudioPostInfer
self
.
transformer_infer_class
=
WanAudioTransformerInfer
class
Wan22MoeAudioModel
(
WanAudioModel
):
class
Wan22MoeAudioModel
(
WanAudioModel
):
...
...
lightx2v/models/networks/wan/infer/audio/post_
wan_audio_
infer.py
→
lightx2v/models/networks/wan/infer/audio/post_infer.py
View file @
fba9754a
File moved
lightx2v/models/networks/wan/infer/audio/pre_
wan_audio_
infer.py
→
lightx2v/models/networks/wan/infer/audio/pre_infer.py
View file @
fba9754a
File moved
lightx2v/models/networks/wan/infer/audio/transformer_infer.py
0 → 100644
View file @
fba9754a
from
lightx2v.models.networks.wan.infer.offload.transformer_infer
import
WanOffloadTransformerInfer
from
lightx2v.models.networks.wan.infer.utils
import
compute_freqs_audio
,
compute_freqs_audio_dist
class
WanAudioTransformerInfer
(
WanOffloadTransformerInfer
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
def
compute_freqs
(
self
,
q
,
grid_sizes
,
freqs
):
if
self
.
config
[
"seq_parallel"
]:
freqs_i
=
compute_freqs_audio_dist
(
q
.
size
(
0
),
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
,
self
.
seq_p_group
)
else
:
freqs_i
=
compute_freqs_audio
(
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
)
return
freqs_i
def
post_process
(
self
,
x
,
y
,
c_gate_msa
,
pre_infer_out
):
x
=
super
().
post_process
(
x
,
y
,
c_gate_msa
,
pre_infer_out
)
# Apply audio_dit if available
if
pre_infer_out
.
audio_dit_blocks
is
not
None
and
hasattr
(
self
,
"block_idx"
):
for
ipa_out
in
pre_infer_out
.
audio_dit_blocks
:
if
self
.
block_idx
in
ipa_out
:
cur_modify
=
ipa_out
[
self
.
block_idx
]
x
=
cur_modify
[
"modify_func"
](
x
,
pre_infer_out
.
grid_sizes
,
**
cur_modify
[
"kwargs"
])
return
x
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py
View file @
fba9754a
...
@@ -2,13 +2,13 @@ import math
...
@@ -2,13 +2,13 @@ import math
import
torch
import
torch
from
lightx2v.models.networks.wan.infer.offload.transformer_infer
import
WanOffloadTransformerInfer
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
..transformer_infer
import
WanTransformerInfer
from
..utils
import
apply_rotary_emb
,
compute_freqs_causvid
from
..utils
import
apply_rotary_emb
,
compute_freqs_causvid
class
WanTransformerInferCausVid
(
WanTransformerInfer
):
class
WanTransformerInferCausVid
(
Wan
Offload
TransformerInfer
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
num_frames
=
config
[
"num_frames"
]
self
.
num_frames
=
config
[
"num_frames"
]
...
...
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py
View file @
fba9754a
...
@@ -6,11 +6,10 @@ import torch
...
@@ -6,11 +6,10 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
lightx2v.common.transformer_infer.transformer_infer
import
BaseTaylorCachingTransformerInfer
from
lightx2v.common.transformer_infer.transformer_infer
import
BaseTaylorCachingTransformerInfer
from
lightx2v.models.networks.wan.infer.offload.transformer_infer
import
WanOffloadTransformerInfer
from
..transformer_infer
import
WanTransformerInfer
class
WanTransformerInferCaching
(
WanOffloadTransformerInfer
):
class
WanTransformerInferCaching
(
WanTransformerInfer
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
must_calc_steps
=
[]
self
.
must_calc_steps
=
[]
...
...
lightx2v/models/networks/wan/infer/offload/__init__.py
0 → 100755
View file @
fba9754a
lightx2v/models/networks/wan/infer/offload/transformer_infer.py
0 → 100644
View file @
fba9754a
import
torch
from
lightx2v.common.offload.manager
import
(
LazyWeightAsyncStreamManager
,
WeightAsyncStreamManager
,
)
from
..transformer_infer
import
WanTransformerInfer
class
WanOffloadTransformerInfer
(
WanTransformerInfer
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
"offload_ratio"
in
self
.
config
:
offload_ratio
=
self
.
config
[
"offload_ratio"
]
else
:
offload_ratio
=
1
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
if
offload_granularity
==
"block"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
infer_func
=
self
.
infer_with_offload
else
:
self
.
infer_func
=
self
.
infer_with_lazy_offload
elif
offload_granularity
==
"phase"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
infer_func
=
self
.
infer_with_phases_offload
else
:
self
.
infer_func
=
self
.
infer_with_phases_lazy_offload
elif
offload_granularity
==
"model"
:
self
.
infer_func
=
self
.
_infer_without_offload
if
offload_granularity
!=
"model"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
weights_stream_mgr
=
WeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
)
else
:
self
.
weights_stream_mgr
=
LazyWeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
num_disk_workers
=
self
.
config
.
get
(
"num_disk_workers"
,
2
),
max_memory
=
self
.
config
.
get
(
"max_memory"
,
2
),
offload_gra
=
offload_granularity
,
)
def
infer_with_offload
(
self
,
weights
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
self
.
blocks_num
):
self
.
block_idx
=
block_idx
if
block_idx
==
0
:
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
blocks
[
0
]
self
.
weights_stream_mgr
.
active_weights
[
0
].
to_cuda
()
if
block_idx
<
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
prefetch_weights
(
block_idx
+
1
,
weights
.
blocks
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
x
,
pre_infer_out
)
self
.
weights_stream_mgr
.
swap_weights
()
return
x
def
infer_with_lazy_offload
(
self
,
weights
,
x
,
pre_infer_out
):
self
.
weights_stream_mgr
.
prefetch_weights_from_disk
(
weights
.
blocks
)
for
block_idx
in
range
(
self
.
blocks_num
):
if
block_idx
==
0
:
block
=
self
.
weights_stream_mgr
.
pin_memory_buffer
.
get
(
block_idx
)
block
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
block_idx
,
block
)
if
block_idx
<
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
prefetch_weights
(
block_idx
+
1
,
weights
.
blocks
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
x
,
pre_infer_out
)
self
.
weights_stream_mgr
.
swap_weights
()
if
block_idx
==
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
pin_memory_buffer
.
pop_front
()
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
if
self
.
clean_cuda_cache
:
del
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
embed0
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
pre_infer_out
.
context
torch
.
cuda
.
empty_cache
()
return
x
def
infer_with_phases_offload
(
self
,
weights
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
weights
.
blocks_num
):
self
.
block_idx
=
block_idx
for
phase_idx
in
range
(
self
.
phases_num
):
if
block_idx
==
0
and
phase_idx
==
0
:
phase
=
weights
.
blocks
[
block_idx
].
compute_phases
[
phase_idx
]
phase
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
phase_idx
,
phase
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
cur_phase_idx
,
cur_phase
=
self
.
weights_stream_mgr
.
active_weights
[
0
]
if
cur_phase_idx
==
0
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
cur_phase
,
pre_infer_out
.
embed0
)
elif
cur_phase_idx
==
1
:
y_out
=
self
.
infer_self_attn
(
cur_phase
,
pre_infer_out
.
grid_sizes
,
x
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
shift_msa
,
scale_msa
,
)
elif
cur_phase_idx
==
2
:
x
,
attn_out
=
self
.
infer_cross_attn
(
cur_phase
,
x
,
pre_infer_out
.
context
,
y_out
,
gate_msa
)
elif
cur_phase_idx
==
3
:
y
=
self
.
infer_ffn
(
cur_phase
,
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
pre_infer_out
)
is_last_phase
=
block_idx
==
weights
.
blocks_num
-
1
and
phase_idx
==
self
.
phases_num
-
1
if
not
is_last_phase
:
next_block_idx
=
block_idx
+
1
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
phases_num
self
.
weights_stream_mgr
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
weights
.
blocks
)
self
.
weights_stream_mgr
.
swap_phases
()
if
self
.
clean_cuda_cache
:
del
attn_out
,
y_out
,
y
torch
.
cuda
.
empty_cache
()
if
self
.
clean_cuda_cache
:
del
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
del
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
embed0
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
pre_infer_out
.
context
torch
.
cuda
.
empty_cache
()
return
x
def
infer_with_phases_lazy_offload
(
self
,
weights
,
x
,
pre_infer_out
):
self
.
weights_stream_mgr
.
prefetch_weights_from_disk
(
weights
.
blocks
)
for
block_idx
in
range
(
weights
.
blocks_num
):
self
.
block_idx
=
block_idx
for
phase_idx
in
range
(
self
.
weights_stream_mgr
.
phases_num
):
if
block_idx
==
0
and
phase_idx
==
0
:
obj_key
=
(
block_idx
,
phase_idx
)
phase
=
self
.
weights_stream_mgr
.
pin_memory_buffer
.
get
(
obj_key
)
phase
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
obj_key
,
phase
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
(
(
_
,
cur_phase_idx
,
),
cur_phase
,
)
=
self
.
weights_stream_mgr
.
active_weights
[
0
]
if
cur_phase_idx
==
0
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
cur_phase
,
pre_infer_out
.
embed0
)
elif
cur_phase_idx
==
1
:
y_out
=
self
.
infer_self_attn
(
cur_phase
,
pre_infer_out
.
grid_sizes
,
x
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
shift_msa
,
scale_msa
,
)
elif
cur_phase_idx
==
2
:
x
,
attn_out
=
self
.
infer_cross_attn
(
cur_phase
,
x
,
pre_infer_out
.
context
,
y_out
,
gate_msa
)
elif
cur_phase_idx
==
3
:
y
=
self
.
infer_ffn
(
cur_phase
,
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
pre_infer_out
)
if
not
(
block_idx
==
weights
.
blocks_num
-
1
and
phase_idx
==
self
.
phases_num
-
1
):
next_block_idx
=
block_idx
+
1
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
weights_stream_mgr
.
phases_num
self
.
weights_stream_mgr
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
weights
.
blocks
)
self
.
weights_stream_mgr
.
swap_phases
()
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
if
self
.
clean_cuda_cache
:
del
attn_out
,
y_out
,
y
torch
.
cuda
.
empty_cache
()
if
self
.
clean_cuda_cache
:
del
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
del
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
embed0
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
pre_infer_out
.
context
torch
.
cuda
.
empty_cache
()
return
x
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
fba9754a
...
@@ -2,14 +2,10 @@ from functools import partial
...
@@ -2,14 +2,10 @@ from functools import partial
import
torch
import
torch
from
lightx2v.common.offload.manager
import
(
LazyWeightAsyncStreamManager
,
WeightAsyncStreamManager
,
)
from
lightx2v.common.transformer_infer.transformer_infer
import
BaseTransformerInfer
from
lightx2v.common.transformer_infer.transformer_infer
import
BaseTransformerInfer
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
.utils
import
apply_rotary_emb
,
apply_rotary_emb_chunk
,
compute_freqs
,
compute_freqs_audio
,
compute_freqs_audio_dist
,
compute_freqs_dist
from
.utils
import
apply_rotary_emb
,
apply_rotary_emb_chunk
,
compute_freqs
,
compute_freqs_dist
class
WanTransformerInfer
(
BaseTransformerInfer
):
class
WanTransformerInfer
(
BaseTransformerInfer
):
...
@@ -37,46 +33,7 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -37,46 +33,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self
.
seq_p_group
=
self
.
config
.
get
(
"device_mesh"
).
get_group
(
mesh_dim
=
"seq_p"
)
self
.
seq_p_group
=
self
.
config
.
get
(
"device_mesh"
).
get_group
(
mesh_dim
=
"seq_p"
)
else
:
else
:
self
.
seq_p_group
=
None
self
.
seq_p_group
=
None
self
.
infer_func
=
self
.
infer_without_offload
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
# if torch.cuda.get_device_capability(0) == (9, 0):
# assert self.config["self_attn_1_type"] != "sage_attn2"
if
"offload_ratio"
in
self
.
config
:
offload_ratio
=
self
.
config
[
"offload_ratio"
]
else
:
offload_ratio
=
1
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
if
offload_granularity
==
"block"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
infer_func
=
self
.
_infer_with_offload
else
:
self
.
infer_func
=
self
.
_infer_with_lazy_offload
elif
offload_granularity
==
"phase"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
infer_func
=
self
.
_infer_with_phases_offload
else
:
self
.
infer_func
=
self
.
_infer_with_phases_lazy_offload
elif
offload_granularity
==
"model"
:
self
.
infer_func
=
self
.
_infer_without_offload
if
offload_granularity
!=
"model"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
weights_stream_mgr
=
WeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
)
else
:
self
.
weights_stream_mgr
=
LazyWeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
num_disk_workers
=
self
.
config
.
get
(
"num_disk_workers"
,
2
),
max_memory
=
self
.
config
.
get
(
"max_memory"
,
2
),
offload_gra
=
offload_granularity
,
)
else
:
self
.
infer_func
=
self
.
_infer_without_offload
def
_calculate_q_k_len
(
self
,
q
,
k_lens
):
def
_calculate_q_k_len
(
self
,
q
,
k_lens
):
q_lens
=
torch
.
tensor
([
q
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
q
.
device
)
q_lens
=
torch
.
tensor
([
q
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
q
.
device
)
...
@@ -86,36 +43,20 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -86,36 +43,20 @@ class WanTransformerInfer(BaseTransformerInfer):
def
compute_freqs
(
self
,
q
,
grid_sizes
,
freqs
):
def
compute_freqs
(
self
,
q
,
grid_sizes
,
freqs
):
if
self
.
config
[
"seq_parallel"
]:
if
self
.
config
[
"seq_parallel"
]:
if
"audio"
in
self
.
config
.
get
(
"model_cls"
,
""
):
freqs_i
=
compute_freqs_dist
(
q
.
size
(
0
),
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
,
self
.
seq_p_group
)
freqs_i
=
compute_freqs_audio_dist
(
q
.
size
(
0
),
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
,
self
.
seq_p_group
)
else
:
freqs_i
=
compute_freqs_dist
(
q
.
size
(
0
),
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
,
self
.
seq_p_group
)
else
:
else
:
if
"audio"
in
self
.
config
.
get
(
"model_cls"
,
""
):
freqs_i
=
compute_freqs
(
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
)
freqs_i
=
compute_freqs_audio
(
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
)
else
:
freqs_i
=
compute_freqs
(
q
.
size
(
2
)
//
2
,
grid_sizes
,
freqs
)
return
freqs_i
return
freqs_i
def
infer
(
self
,
weights
,
pre_infer_out
):
def
infer
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_main_blocks
(
weights
,
pre_infer_out
)
x
=
self
.
infer_main_blocks
(
weights
,
pre_infer_out
)
return
self
.
infer_
post
_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
return
self
.
infer_
non
_blocks
(
weights
,
x
,
pre_infer_out
.
embed
)
def
infer_main_blocks
(
self
,
weights
,
pre_infer_out
):
def
infer_main_blocks
(
self
,
weights
,
pre_infer_out
):
x
=
self
.
infer_func
(
x
=
self
.
infer_func
(
weights
,
pre_infer_out
.
x
,
pre_infer_out
)
weights
,
pre_infer_out
.
grid_sizes
,
pre_infer_out
.
embed
,
pre_infer_out
.
x
,
pre_infer_out
.
embed0
,
pre_infer_out
.
seq_lens
,
pre_infer_out
.
freqs
,
pre_infer_out
.
context
,
pre_infer_out
.
audio_dit_blocks
,
)
return
x
return
x
def
infer_
post
_blocks
(
self
,
weights
,
x
,
e
):
def
infer_
non
_blocks
(
self
,
weights
,
x
,
e
):
if
e
.
dim
()
==
2
:
if
e
.
dim
()
==
2
:
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
modulation
=
weights
.
head_modulation
.
tensor
# 1, 2, dim
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
e
=
(
modulation
+
e
.
unsqueeze
(
1
)).
chunk
(
2
,
dim
=
1
)
...
@@ -139,214 +80,29 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -139,214 +80,29 @@ class WanTransformerInfer(BaseTransformerInfer):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
x
return
x
def
_infer_without_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
def
infer_without_offload
(
self
,
weights
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
self
.
blocks_num
):
self
.
block_idx
=
block_idx
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
,
)
return
x
def
_infer_with_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
for
block_idx
in
range
(
self
.
blocks_num
):
for
block_idx
in
range
(
self
.
blocks_num
):
self
.
block_idx
=
block_idx
self
.
block_idx
=
block_idx
if
block_idx
==
0
:
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
x
,
pre_infer_out
)
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
blocks
[
0
]
self
.
weights_stream_mgr
.
active_weights
[
0
].
to_cuda
()
if
block_idx
<
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
prefetch_weights
(
block_idx
+
1
,
weights
.
blocks
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
x
=
self
.
infer_block
(
self
.
weights_stream_mgr
.
active_weights
[
0
],
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
,
)
self
.
weights_stream_mgr
.
swap_weights
()
return
x
def
_infer_with_lazy_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
self
.
weights_stream_mgr
.
prefetch_weights_from_disk
(
weights
.
blocks
)
for
block_idx
in
range
(
self
.
blocks_num
):
if
block_idx
==
0
:
block
=
self
.
weights_stream_mgr
.
pin_memory_buffer
.
get
(
block_idx
)
block
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
block_idx
,
block
)
if
block_idx
<
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
prefetch_weights
(
block_idx
+
1
,
weights
.
blocks
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
x
=
self
.
infer_block
(
self
.
weights_stream_mgr
.
active_weights
[
0
][
1
],
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
,
)
self
.
weights_stream_mgr
.
swap_weights
()
if
block_idx
==
self
.
blocks_num
-
1
:
self
.
weights_stream_mgr
.
pin_memory_buffer
.
pop_front
()
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
if
self
.
clean_cuda_cache
:
del
grid_sizes
,
embed
,
embed0
,
seq_lens
,
freqs
,
context
torch
.
cuda
.
empty_cache
()
return
x
return
x
def
_infer_with_phases_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
def
infer_block
(
self
,
weights
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
weights
.
blocks_num
):
self
.
block_idx
=
block_idx
for
phase_idx
in
range
(
self
.
phases_num
):
if
block_idx
==
0
and
phase_idx
==
0
:
phase
=
weights
.
blocks
[
block_idx
].
compute_phases
[
phase_idx
]
phase
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
phase_idx
,
phase
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
cur_phase_idx
,
cur_phase
=
self
.
weights_stream_mgr
.
active_weights
[
0
]
if
cur_phase_idx
==
0
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
cur_phase
,
embed0
)
elif
cur_phase_idx
==
1
:
y_out
=
self
.
infer_self_attn
(
cur_phase
,
grid_sizes
,
x
,
seq_lens
,
freqs
,
shift_msa
,
scale_msa
,
)
elif
cur_phase_idx
==
2
:
x
,
attn_out
=
self
.
infer_cross_attn
(
cur_phase
,
x
,
context
,
y_out
,
gate_msa
)
elif
cur_phase_idx
==
3
:
y
=
self
.
infer_ffn
(
cur_phase
,
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
grid_sizes
,
audio_dit_blocks
)
is_last_phase
=
block_idx
==
weights
.
blocks_num
-
1
and
phase_idx
==
self
.
phases_num
-
1
if
not
is_last_phase
:
next_block_idx
=
block_idx
+
1
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
phases_num
self
.
weights_stream_mgr
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
weights
.
blocks
)
self
.
weights_stream_mgr
.
swap_phases
()
if
self
.
clean_cuda_cache
:
del
attn_out
,
y_out
,
y
torch
.
cuda
.
empty_cache
()
if
self
.
clean_cuda_cache
:
del
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
del
grid_sizes
,
embed
,
embed0
,
seq_lens
,
freqs
,
context
torch
.
cuda
.
empty_cache
()
return
x
def
_infer_with_phases_lazy_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
self
.
weights_stream_mgr
.
prefetch_weights_from_disk
(
weights
.
blocks
)
for
block_idx
in
range
(
weights
.
blocks_num
):
self
.
block_idx
=
block_idx
for
phase_idx
in
range
(
self
.
weights_stream_mgr
.
phases_num
):
if
block_idx
==
0
and
phase_idx
==
0
:
obj_key
=
(
block_idx
,
phase_idx
)
phase
=
self
.
weights_stream_mgr
.
pin_memory_buffer
.
get
(
obj_key
)
phase
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
obj_key
,
phase
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
(
(
_
,
cur_phase_idx
,
),
cur_phase
,
)
=
self
.
weights_stream_mgr
.
active_weights
[
0
]
if
cur_phase_idx
==
0
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
cur_phase
,
embed0
,
)
elif
cur_phase_idx
==
1
:
y_out
=
self
.
infer_self_attn
(
cur_phase
,
grid_sizes
,
x
,
seq_lens
,
freqs
,
shift_msa
,
scale_msa
,
)
elif
cur_phase_idx
==
2
:
x
,
attn_out
=
self
.
infer_cross_attn
(
cur_phase
,
x
,
context
,
y_out
,
gate_msa
)
elif
cur_phase_idx
==
3
:
y
=
self
.
infer_ffn
(
cur_phase
,
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
grid_sizes
,
audio_dit_blocks
)
if
not
(
block_idx
==
weights
.
blocks_num
-
1
and
phase_idx
==
self
.
phases_num
-
1
):
next_block_idx
=
block_idx
+
1
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
weights_stream_mgr
.
phases_num
self
.
weights_stream_mgr
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
weights
.
blocks
)
self
.
weights_stream_mgr
.
swap_phases
()
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
if
self
.
clean_cuda_cache
:
del
attn_out
,
y_out
,
y
torch
.
cuda
.
empty_cache
()
if
self
.
clean_cuda_cache
:
del
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
del
grid_sizes
,
embed
,
embed0
,
seq_lens
,
freqs
,
context
torch
.
cuda
.
empty_cache
()
return
x
def
infer_block
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
,
audio_dit_blocks
=
None
):
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
self
.
infer_modulation
(
weights
.
compute_phases
[
0
],
weights
.
compute_phases
[
0
],
embed0
,
pre_infer_out
.
embed0
,
)
)
y_out
=
self
.
infer_self_attn
(
y_out
=
self
.
infer_self_attn
(
weights
.
compute_phases
[
1
],
weights
.
compute_phases
[
1
],
grid_sizes
,
pre_infer_out
.
grid_sizes
,
x
,
x
,
seq_lens
,
pre_infer_out
.
seq_lens
,
freqs
,
pre_infer_out
.
freqs
,
shift_msa
,
shift_msa
,
scale_msa
,
scale_msa
,
)
)
x
,
attn_out
=
self
.
infer_cross_attn
(
weights
.
compute_phases
[
2
],
x
,
context
,
y_out
,
gate_msa
)
x
,
attn_out
=
self
.
infer_cross_attn
(
weights
.
compute_phases
[
2
],
x
,
pre_infer_out
.
context
,
y_out
,
gate_msa
)
y
=
self
.
infer_ffn
(
weights
.
compute_phases
[
3
],
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
y
=
self
.
infer_ffn
(
weights
.
compute_phases
[
3
],
x
,
attn_out
,
c_shift_msa
,
c_scale_msa
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
grid_sizes
,
audio_dit_blocks
)
x
=
self
.
post_process
(
x
,
y
,
c_gate_msa
,
pre_infer_out
)
return
x
return
x
def
infer_modulation
(
self
,
weights
,
embed0
):
def
infer_modulation
(
self
,
weights
,
embed0
):
...
@@ -531,19 +287,12 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -531,19 +287,12 @@ class WanTransformerInfer(BaseTransformerInfer):
return
y
return
y
def
post_process
(
self
,
x
,
y
,
c_gate_msa
,
grid_sizes
,
audio_dit_blocks
=
None
):
def
post_process
(
self
,
x
,
y
,
c_gate_msa
,
pre_infer_out
):
if
self
.
sensitive_layer_dtype
!=
self
.
infer_dtype
:
if
self
.
sensitive_layer_dtype
!=
self
.
infer_dtype
:
x
=
x
.
to
(
self
.
sensitive_layer_dtype
)
+
y
.
to
(
self
.
sensitive_layer_dtype
)
*
c_gate_msa
.
squeeze
()
x
=
x
.
to
(
self
.
sensitive_layer_dtype
)
+
y
.
to
(
self
.
sensitive_layer_dtype
)
*
c_gate_msa
.
squeeze
()
else
:
else
:
x
.
add_
(
y
*
c_gate_msa
.
squeeze
())
x
.
add_
(
y
*
c_gate_msa
.
squeeze
())
# Apply audio_dit if available
if
audio_dit_blocks
is
not
None
and
hasattr
(
self
,
"block_idx"
):
for
ipa_out
in
audio_dit_blocks
:
if
self
.
block_idx
in
ipa_out
:
cur_modify
=
ipa_out
[
self
.
block_idx
]
x
=
cur_modify
[
"modify_func"
](
x
,
grid_sizes
,
**
cur_modify
[
"kwargs"
])
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
y
,
c_gate_msa
del
y
,
c_gate_msa
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
lightx2v/models/networks/wan/model.py
View file @
fba9754a
...
@@ -18,6 +18,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
...
@@ -18,6 +18,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
WanTransformerInferTaylorCaching
,
WanTransformerInferTaylorCaching
,
WanTransformerInferTeaCaching
,
WanTransformerInferTeaCaching
,
)
)
from
lightx2v.models.networks.wan.infer.offload.transformer_infer
import
(
WanOffloadTransformerInfer
,
)
from
lightx2v.models.networks.wan.infer.post_infer
import
WanPostInfer
from
lightx2v.models.networks.wan.infer.post_infer
import
WanPostInfer
from
lightx2v.models.networks.wan.infer.pre_infer
import
WanPreInfer
from
lightx2v.models.networks.wan.infer.pre_infer
import
WanPreInfer
from
lightx2v.models.networks.wan.infer.transformer_infer
import
(
from
lightx2v.models.networks.wan.infer.transformer_infer
import
(
...
@@ -64,7 +67,12 @@ class WanModel:
...
@@ -64,7 +67,12 @@ class WanModel:
self
.
dit_quantized_ckpt
=
find_gguf_model_path
(
config
,
"dit_quantized_ckpt"
,
subdir
=
dit_quant_scheme
)
self
.
dit_quantized_ckpt
=
find_gguf_model_path
(
config
,
"dit_quantized_ckpt"
,
subdir
=
dit_quant_scheme
)
self
.
config
.
use_gguf
=
True
self
.
config
.
use_gguf
=
True
else
:
else
:
self
.
dit_quantized_ckpt
=
find_hf_model_path
(
config
,
self
.
model_path
,
"dit_quantized_ckpt"
,
subdir
=
dit_quant_scheme
)
self
.
dit_quantized_ckpt
=
find_hf_model_path
(
config
,
self
.
model_path
,
"dit_quantized_ckpt"
,
subdir
=
dit_quant_scheme
,
)
quant_config_path
=
os
.
path
.
join
(
self
.
dit_quantized_ckpt
,
"config.json"
)
quant_config_path
=
os
.
path
.
join
(
self
.
dit_quantized_ckpt
,
"config.json"
)
if
os
.
path
.
exists
(
quant_config_path
):
if
os
.
path
.
exists
(
quant_config_path
):
with
open
(
quant_config_path
,
"r"
)
as
f
:
with
open
(
quant_config_path
,
"r"
)
as
f
:
...
@@ -90,7 +98,7 @@ class WanModel:
...
@@ -90,7 +98,7 @@ class WanModel:
self
.
post_infer_class
=
WanPostInfer
self
.
post_infer_class
=
WanPostInfer
if
self
.
config
[
"feature_caching"
]
==
"NoCaching"
:
if
self
.
config
[
"feature_caching"
]
==
"NoCaching"
:
self
.
transformer_infer_class
=
WanTransformerInfer
self
.
transformer_infer_class
=
WanTransformerInfer
if
not
self
.
cpu_offload
else
WanOffloadTransformerInfer
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
elif
self
.
config
[
"feature_caching"
]
==
"Tea"
:
self
.
transformer_infer_class
=
WanTransformerInferTeaCaching
self
.
transformer_infer_class
=
WanTransformerInferTeaCaching
elif
self
.
config
[
"feature_caching"
]
==
"TaylorSeer"
:
elif
self
.
config
[
"feature_caching"
]
==
"TaylorSeer"
:
...
@@ -158,7 +166,11 @@ class WanModel:
...
@@ -158,7 +166,11 @@ class WanModel:
with
safe_open
(
safetensor_path
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
safetensor_path
,
framework
=
"pt"
)
as
f
:
logger
.
info
(
f
"Loading weights from
{
safetensor_path
}
"
)
logger
.
info
(
f
"Loading weights from
{
safetensor_path
}
"
)
for
k
in
f
.
keys
():
for
k
in
f
.
keys
():
if
f
.
get_tensor
(
k
).
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float
]:
if
f
.
get_tensor
(
k
).
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float
,
]:
if
unified_dtype
or
all
(
s
not
in
k
for
s
in
sensitive_layer
):
if
unified_dtype
or
all
(
s
not
in
k
for
s
in
sensitive_layer
):
weight_dict
[
k
]
=
f
.
get_tensor
(
k
).
pin_memory
().
to
(
GET_DTYPE
()).
to
(
self
.
device
)
weight_dict
[
k
]
=
f
.
get_tensor
(
k
).
pin_memory
().
to
(
GET_DTYPE
()).
to
(
self
.
device
)
else
:
else
:
...
@@ -176,7 +188,11 @@ class WanModel:
...
@@ -176,7 +188,11 @@ class WanModel:
safetensor_path
=
os
.
path
.
join
(
lazy_load_model_path
,
"non_block.safetensors"
)
safetensor_path
=
os
.
path
.
join
(
lazy_load_model_path
,
"non_block.safetensors"
)
with
safe_open
(
safetensor_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
f
:
with
safe_open
(
safetensor_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
f
:
for
k
in
f
.
keys
():
for
k
in
f
.
keys
():
if
f
.
get_tensor
(
k
).
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float
]:
if
f
.
get_tensor
(
k
).
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float
,
]:
if
unified_dtype
or
all
(
s
not
in
k
for
s
in
sensitive_layer
):
if
unified_dtype
or
all
(
s
not
in
k
for
s
in
sensitive_layer
):
pre_post_weight_dict
[
k
]
=
f
.
get_tensor
(
k
).
pin_memory
().
to
(
GET_DTYPE
()).
to
(
self
.
device
)
pre_post_weight_dict
[
k
]
=
f
.
get_tensor
(
k
).
pin_memory
().
to
(
GET_DTYPE
()).
to
(
self
.
device
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment