xuwx1 / LightX2V · Commit 0b755a97

Fix offload bug in new dist infer

Authored Aug 15, 2025 by gushiqiao; committed by GitHub, Aug 15, 2025.
Parents: 88433448, ff66b814
Showing 5 changed files with 17 additions and 11 deletions (+17 -11):

  lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py   +1  -1
  lightx2v/models/networks/wan/infer/post_infer.py                    +1  -1
  lightx2v/models/networks/wan/infer/transformer_infer.py             +2  -2
  lightx2v/models/networks/wan/model.py                               +3  -7
  lightx2v/models/networks/wan/weights/transformer_weights.py        +10  -0
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py

@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
         self.scheduler = scheduler
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = x[:, : pre_infer_out.valid_patch_length]
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
...
lightx2v/models/networks/wan/infer/post_infer.py

@@ -15,7 +15,7 @@ class WanPostInfer:
         self.scheduler = scheduler
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
         if self.clean_cuda_cache:
...
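Both post-infer classes drop the now-unused weights parameter from infer, so callers pass only the hidden states and the pre-infer output. Below is a minimal sketch of the updated call convention; the ToyPostInfer and PreInferOut stand-ins, the tensor shapes, and the simplified unpatchify are illustrative assumptions, not the repository implementation.

# Illustrative sketch only: updated WanPostInfer.infer call convention (no weights arg).
from dataclasses import dataclass

import torch


@dataclass
class PreInferOut:  # hypothetical container mirroring how pre_infer_out is used in the diff
    grid_sizes: tuple
    valid_patch_length: int


class ToyPostInfer:
    def unpatchify(self, x, grid_sizes):
        # Stand-in for the real unpatchify: reshape patches back onto the grid.
        return x.reshape(*grid_sizes, -1)

    # Old signature: infer(self, weights, x, pre_infer_out)
    # New signature: the weights argument is gone; post weights are managed elsewhere.
    def infer(self, x, pre_infer_out):
        x = x[:, : pre_infer_out.valid_patch_length]
        return self.unpatchify(x, pre_infer_out.grid_sizes)


if __name__ == "__main__":
    post = ToyPostInfer()
    out = post.infer(torch.randn(1, 16, 8), PreInferOut(grid_sizes=(1, 4, 4), valid_patch_length=16))
    print(out.shape)  # torch.Size([1, 4, 4, 8])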
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.seq_p_group = None
         if self.config.get("cpu_offload", False):
-            if torch.cuda.get_device_capability(0) == (9, 0):
-                assert self.config["self_attn_1_type"] != "sage_attn2"
+            # if torch.cuda.get_device_capability(0) == (9, 0):
+            #     assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:
...
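The unchanged context lines in this hunk show how the offload settings are read when cpu_offload is enabled: an explicit offload_ratio in the config wins, otherwise a default from the (not shown) else branch is used. A minimal sketch of that pattern, assuming a plain dict config; the 1.0 default below is a placeholder assumption, not a value taken from the repository.

# Sketch of the cpu_offload config handling shown in the hunk above.
def resolve_offload_ratio(config: dict) -> float | None:
    if not config.get("cpu_offload", False):
        return None  # offloading disabled, nothing to resolve
    if "offload_ratio" in config:
        return config["offload_ratio"]
    return 1.0  # placeholder default; the real default lives in the elided else branch


print(resolve_offload_ratio({"cpu_offload": True, "offload_ratio": 0.5}))  # 0.5
print(resolve_offload_ratio({"cpu_offload": True}))                        # 1.0 (placeholder)
print(resolve_offload_ratio({}))                                           # None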
lightx2v/models/networks/wan/model.py

@@ -225,12 +225,10 @@ class WanModel:
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
-        self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
     def _load_weights_distribute(self, weight_dict, is_weight_loader):
...
@@ -303,12 +301,10 @@ class WanModel:
     def to_cpu(self):
         self.pre_weight.to_cpu()
-        self.post_weight.to_cpu()
         self.transformer_weights.to_cpu()
     def to_cuda(self):
         self.pre_weight.to_cuda()
-        self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()
     @torch.no_grad()
...
@@ -318,7 +314,7 @@ class WanModel:
             self.to_cuda()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
+            self.transformer_weights.post_weights_to_cuda()
         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape
...
@@ -356,7 +352,7 @@ class WanModel:
             self.to_cpu()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
+            self.transformer_weights.post_weights_to_cpu()
     @torch.no_grad()
     def _infer_cond_uncond(self, inputs, positive=True):
...
@@ -370,7 +366,7 @@ class WanModel:
         if self.config["seq_parallel"]:
             x = self._seq_parallel_post_process(x)
-        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
+        noise_pred = self.post_infer.infer(x, pre_infer_out)[0]
         if self.clean_cuda_cache:
             del x, pre_infer_out
...
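The model.py hunks show the shape of the fix: WanModel no longer keeps a separate post_weight container, and when offloading at sub-model granularity it delegates the post-processing weights to the transformer weight container instead. A toy sketch of that flow follows; all class and method names other than post_weights_to_cuda are illustrative stand-ins (the real guarding condition and caller name are elided in the diff).

# Toy sketch of the offload flow after this commit; not the repository classes.
class ToyPreWeights:
    def to_cuda(self):
        print("pre weights -> cuda")


class ToyTransformerWeights:
    def to_cuda(self):
        print("all transformer weights -> cuda")

    def post_weights_to_cuda(self):
        print("post weights (norm/head) -> cuda")


class ToyModel:
    def __init__(self, offload_granularity):
        self.offload_granularity = offload_granularity
        self.pre_weight = ToyPreWeights()
        self.transformer_weights = ToyTransformerWeights()
        # Note: no separate self.post_weight container anymore.

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.transformer_weights.to_cuda()

    def prepare_for_infer(self):  # hypothetical name for the elided caller
        if self.offload_granularity == "model":
            self.to_cuda()  # whole model moves at once
        else:
            # Fine-grained offload: blocks move later, phase by phase; only the
            # pre weights and the post weights are moved up front.
            self.pre_weight.to_cuda()
            self.transformer_weights.post_weights_to_cuda()


ToyModel("block").prepare_for_infer()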
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
             for phase in block.compute_phases:
                 phase.clear()
+    def post_weights_to_cuda(self):
+        self.norm.to_cuda()
+        self.head.to_cuda()
+        self.head_modulation.to_cuda()
+
+    def post_weights_to_cpu(self):
+        self.norm.to_cpu()
+        self.head.to_cpu()
+        self.head_modulation.to_cpu()
+
 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
...
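The two new helpers move only the post-processing weights (final norm, output head, and head modulation) between devices, leaving the transformer blocks untouched. A minimal sketch of how they behave, using a stand-in weight wrapper that only tracks its device; ToyWeight is illustrative and not the repository's WeightModule.

# Stand-in sketch of the new post_weights_to_cuda / post_weights_to_cpu helpers.
class ToyWeight:
    def __init__(self):
        self.device = "cpu"

    def to_cuda(self):
        self.device = "cuda"

    def to_cpu(self):
        self.device = "cpu"


class ToyWanTransformerWeights:
    def __init__(self):
        self.norm = ToyWeight()
        self.head = ToyWeight()
        self.head_modulation = ToyWeight()

    # Mirrors the methods added in this commit.
    def post_weights_to_cuda(self):
        self.norm.to_cuda()
        self.head.to_cuda()
        self.head_modulation.to_cuda()

    def post_weights_to_cpu(self):
        self.norm.to_cpu()
        self.head.to_cpu()
        self.head_modulation.to_cpu()


w = ToyWanTransformerWeights()
w.post_weights_to_cuda()
print(w.norm.device, w.head.device, w.head_modulation.device)  # cuda cuda cuda
w.post_weights_to_cpu()
print(w.norm.device, w.head.device, w.head_modulation.device)  # cpu cpu cpu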