Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
0b755a97
Commit
0b755a97
authored
Aug 15, 2025
by
gushiqiao
Committed by
GitHub
Aug 15, 2025
Browse files
Fix offload bug in new dist infer
Fix offload bug in new dist infer
parents
88433448
ff66b814
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
17 additions
and
11 deletions
+17
-11
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
...v/models/networks/wan/infer/audio/post_wan_audio_infer.py
+1
-1
lightx2v/models/networks/wan/infer/post_infer.py
lightx2v/models/networks/wan/infer/post_infer.py
+1
-1
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+2
-2
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+3
-7
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+10
-0
No files found.
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
View file @
0b755a97
...
...
@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
self
.
scheduler
=
scheduler
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
x
,
pre_infer_out
):
def
infer
(
self
,
x
,
pre_infer_out
):
x
=
x
[:,
:
pre_infer_out
.
valid_patch_length
]
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
...
...
lightx2v/models/networks/wan/infer/post_infer.py
View file @
0b755a97
...
...
@@ -15,7 +15,7 @@ class WanPostInfer:
self
.
scheduler
=
scheduler
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
x
,
pre_infer_out
):
def
infer
(
self
,
x
,
pre_infer_out
):
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
if
self
.
clean_cuda_cache
:
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
0b755a97
...
...
@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self
.
seq_p_group
=
None
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
torch
.
cuda
.
get_device_capability
(
0
)
==
(
9
,
0
):
assert
self
.
config
[
"self_attn_1_type"
]
!=
"sage_attn2"
#
if torch.cuda.get_device_capability(0) == (9, 0):
#
assert self.config["self_attn_1_type"] != "sage_attn2"
if
"offload_ratio"
in
self
.
config
:
offload_ratio
=
self
.
config
[
"offload_ratio"
]
else
:
...
...
lightx2v/models/networks/wan/model.py
View file @
0b755a97
...
...
@@ -225,12 +225,10 @@ class WanModel:
# Initialize weight containers
self
.
pre_weight
=
self
.
pre_weight_class
(
self
.
config
)
self
.
post_weight
=
self
.
post_weight_class
(
self
.
config
)
self
.
transformer_weights
=
self
.
transformer_weight_class
(
self
.
config
)
# Load weights into containers
self
.
pre_weight
.
load
(
self
.
original_weight_dict
)
self
.
post_weight
.
load
(
self
.
original_weight_dict
)
self
.
transformer_weights
.
load
(
self
.
original_weight_dict
)
def
_load_weights_distribute
(
self
,
weight_dict
,
is_weight_loader
):
...
...
@@ -303,12 +301,10 @@ class WanModel:
def
to_cpu
(
self
):
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
transformer_weights
.
to_cpu
()
def
to_cuda
(
self
):
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
self
.
transformer_weights
.
to_cuda
()
@
torch
.
no_grad
()
...
...
@@ -318,7 +314,7 @@ class WanModel:
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
self
.
transformer_weights
.
post_weights_to_cuda
()
if
self
.
transformer_infer
.
mask_map
is
None
:
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
...
...
@@ -356,7 +352,7 @@ class WanModel:
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
transformer_weights
.
post_weights_to_cpu
()
@
torch
.
no_grad
()
def
_infer_cond_uncond
(
self
,
inputs
,
positive
=
True
):
...
...
@@ -370,7 +366,7 @@ class WanModel:
if
self
.
config
[
"seq_parallel"
]:
x
=
self
.
_seq_parallel_post_process
(
x
)
noise_pred
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
pre_infer_out
)[
0
]
noise_pred
=
self
.
post_infer
.
infer
(
x
,
pre_infer_out
)[
0
]
if
self
.
clean_cuda_cache
:
del
x
,
pre_infer_out
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
0b755a97
...
...
@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
for
phase
in
block
.
compute_phases
:
phase
.
clear
()
def
post_weights_to_cuda
(
self
):
self
.
norm
.
to_cuda
()
self
.
head
.
to_cuda
()
self
.
head_modulation
.
to_cuda
()
def
post_weights_to_cpu
(
self
):
self
.
norm
.
to_cpu
()
self
.
head
.
to_cpu
()
self
.
head_modulation
.
to_cpu
()
class
WanTransformerAttentionBlock
(
WeightModule
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment