Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
ff66b814
Commit
ff66b814
authored
Aug 15, 2025
by
gushiqiao
Browse files
Fix offload bug in new dist infer
parent
cc04b3fb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
17 additions
and
11 deletions
+17
-11
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
...v/models/networks/wan/infer/audio/post_wan_audio_infer.py
+1
-1
lightx2v/models/networks/wan/infer/post_infer.py
lightx2v/models/networks/wan/infer/post_infer.py
+1
-1
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+2
-2
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+3
-7
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+10
-0
No files found.
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
View file @
ff66b814
...
@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
...
@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
x
,
pre_infer_out
):
def
infer
(
self
,
x
,
pre_infer_out
):
x
=
x
[:,
:
pre_infer_out
.
valid_patch_length
]
x
=
x
[:,
:
pre_infer_out
.
valid_patch_length
]
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
...
...
lightx2v/models/networks/wan/infer/post_infer.py
View file @
ff66b814
...
@@ -15,7 +15,7 @@ class WanPostInfer:
...
@@ -15,7 +15,7 @@ class WanPostInfer:
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
x
,
pre_infer_out
):
def
infer
(
self
,
x
,
pre_infer_out
):
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
x
=
self
.
unpatchify
(
x
,
pre_infer_out
.
grid_sizes
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
ff66b814
...
@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self
.
seq_p_group
=
None
self
.
seq_p_group
=
None
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
torch
.
cuda
.
get_device_capability
(
0
)
==
(
9
,
0
):
#
if torch.cuda.get_device_capability(0) == (9, 0):
assert
self
.
config
[
"self_attn_1_type"
]
!=
"sage_attn2"
#
assert self.config["self_attn_1_type"] != "sage_attn2"
if
"offload_ratio"
in
self
.
config
:
if
"offload_ratio"
in
self
.
config
:
offload_ratio
=
self
.
config
[
"offload_ratio"
]
offload_ratio
=
self
.
config
[
"offload_ratio"
]
else
:
else
:
...
...
lightx2v/models/networks/wan/model.py
View file @
ff66b814
...
@@ -225,12 +225,10 @@ class WanModel:
...
@@ -225,12 +225,10 @@ class WanModel:
# Initialize weight containers
# Initialize weight containers
self
.
pre_weight
=
self
.
pre_weight_class
(
self
.
config
)
self
.
pre_weight
=
self
.
pre_weight_class
(
self
.
config
)
self
.
post_weight
=
self
.
post_weight_class
(
self
.
config
)
self
.
transformer_weights
=
self
.
transformer_weight_class
(
self
.
config
)
self
.
transformer_weights
=
self
.
transformer_weight_class
(
self
.
config
)
# Load weights into containers
# Load weights into containers
self
.
pre_weight
.
load
(
self
.
original_weight_dict
)
self
.
pre_weight
.
load
(
self
.
original_weight_dict
)
self
.
post_weight
.
load
(
self
.
original_weight_dict
)
self
.
transformer_weights
.
load
(
self
.
original_weight_dict
)
self
.
transformer_weights
.
load
(
self
.
original_weight_dict
)
def
_load_weights_distribute
(
self
,
weight_dict
,
is_weight_loader
):
def
_load_weights_distribute
(
self
,
weight_dict
,
is_weight_loader
):
...
@@ -303,12 +301,10 @@ class WanModel:
...
@@ -303,12 +301,10 @@ class WanModel:
def
to_cpu
(
self
):
def
to_cpu
(
self
):
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
transformer_weights
.
to_cpu
()
self
.
transformer_weights
.
to_cpu
()
def
to_cuda
(
self
):
def
to_cuda
(
self
):
self
.
pre_weight
.
to_cuda
()
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
self
.
transformer_weights
.
to_cuda
()
self
.
transformer_weights
.
to_cuda
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -318,7 +314,7 @@ class WanModel:
...
@@ -318,7 +314,7 @@ class WanModel:
self
.
to_cuda
()
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
self
.
transformer_weights
.
post_weights_to_cuda
()
if
self
.
transformer_infer
.
mask_map
is
None
:
if
self
.
transformer_infer
.
mask_map
is
None
:
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
...
@@ -356,7 +352,7 @@ class WanModel:
...
@@ -356,7 +352,7 @@ class WanModel:
self
.
to_cpu
()
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
transformer_weights
.
post_weights_to_cpu
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_infer_cond_uncond
(
self
,
inputs
,
positive
=
True
):
def
_infer_cond_uncond
(
self
,
inputs
,
positive
=
True
):
...
@@ -370,7 +366,7 @@ class WanModel:
...
@@ -370,7 +366,7 @@ class WanModel:
if
self
.
config
[
"seq_parallel"
]:
if
self
.
config
[
"seq_parallel"
]:
x
=
self
.
_seq_parallel_post_process
(
x
)
x
=
self
.
_seq_parallel_post_process
(
x
)
noise_pred
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
pre_infer_out
)[
0
]
noise_pred
=
self
.
post_infer
.
infer
(
x
,
pre_infer_out
)[
0
]
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
x
,
pre_infer_out
del
x
,
pre_infer_out
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
ff66b814
...
@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
...
@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
for
phase
in
block
.
compute_phases
:
for
phase
in
block
.
compute_phases
:
phase
.
clear
()
phase
.
clear
()
def
post_weights_to_cuda
(
self
):
self
.
norm
.
to_cuda
()
self
.
head
.
to_cuda
()
self
.
head_modulation
.
to_cuda
()
def
post_weights_to_cpu
(
self
):
self
.
norm
.
to_cpu
()
self
.
head
.
to_cpu
()
self
.
head_modulation
.
to_cpu
()
class
WanTransformerAttentionBlock
(
WeightModule
):
class
WanTransformerAttentionBlock
(
WeightModule
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment