Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
a395cc0a
Commit
a395cc0a
authored
Jul 31, 2025
by
gushiqiao
Committed by
GitHub
Jul 31, 2025
Browse files
Fix offload bugs and support wan2.2_vae offload
Fix offload bugs and support wan2.2_vae offload
parents
b723fc89
aa88b371
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
34 deletions
+61
-34
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+26
-9
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+1
-1
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
+34
-24
No files found.
lightx2v/models/networks/wan/model.py
View file @
a395cc0a
...
@@ -37,6 +37,9 @@ class WanModel:
...
@@ -37,6 +37,9 @@ class WanModel:
def
__init__
(
self
,
model_path
,
config
,
device
):
def
__init__
(
self
,
model_path
,
config
,
device
):
self
.
model_path
=
model_path
self
.
model_path
=
model_path
self
.
config
=
config
self
.
config
=
config
self
.
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
self
.
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
self
.
clean_cuda_cache
=
self
.
config
.
get
(
"clean_cuda_cache"
,
False
)
self
.
clean_cuda_cache
=
self
.
config
.
get
(
"clean_cuda_cache"
,
False
)
self
.
dit_quantized
=
self
.
config
.
mm_config
.
get
(
"mm_type"
,
"Default"
)
!=
"Default"
self
.
dit_quantized
=
self
.
config
.
mm_config
.
get
(
"mm_type"
,
"Default"
)
!=
"Default"
...
@@ -202,15 +205,18 @@ class WanModel:
...
@@ -202,15 +205,18 @@ class WanModel:
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
:
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
if
self
.
transformer_infer
.
mask_map
is
None
:
if
self
.
transformer_infer
.
mask_map
is
None
:
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
video_token_num
=
c
*
(
h
//
2
)
*
(
w
//
2
)
video_token_num
=
c
*
(
h
//
2
)
*
(
w
//
2
)
self
.
transformer_infer
.
mask_map
=
MaskMap
(
video_token_num
,
c
)
self
.
transformer_infer
.
mask_map
=
MaskMap
(
video_token_num
,
c
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
...
@@ -228,14 +234,17 @@ class WanModel:
...
@@ -228,14 +234,17 @@ class WanModel:
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
self
.
clean_cuda_cache
:
del
x
,
embed
,
pre_infer_out
,
noise_pred_uncond
,
grid_sizes
torch
.
cuda
.
empty_cache
()
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
:
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
if
self
.
clean_cuda_cache
:
del
x
,
embed
,
pre_infer_out
,
noise_pred_uncond
,
grid_sizes
torch
.
cuda
.
empty_cache
()
class
Wan22MoeModel
(
WanModel
):
class
Wan22MoeModel
(
WanModel
):
def
_load_ckpt
(
self
,
use_bf16
,
skip_bf16
):
def
_load_ckpt
(
self
,
use_bf16
,
skip_bf16
):
...
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
...
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
if
self
.
cpu_offload
and
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
...
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
...
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
if
self
.
cpu_offload
and
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
lightx2v/models/runners/wan/wan_runner.py
View file @
a395cc0a
...
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
...
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
return
vae_encoder_out
return
vae_encoder_out
def get_vae_encoder_output(self, img):
    """Run the VAE encoder on *img*, forwarding the runner config.

    The config is passed through so the encoder can honor its offload
    settings (e.g. ``cpu_offload``) during the encode call.
    """
    return self.vae_encoder.encode(img, self.config)
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
View file @
a395cc0a
...
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
...
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
mean
=
torch
.
tensor
(
self
.
mean
=
torch
.
tensor
(
[
[
-
0.2289
,
-
0.2289
,
-
0.0052
,
-
0.0052
,
...
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
...
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
)
)
std
=
torch
.
tensor
(
self
.
std
=
torch
.
tensor
(
[
[
0.4765
,
0.4765
,
1.0364
,
1.0364
,
...
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
...
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
)
)
self
.
scale
=
[
mean
,
1.0
/
std
]
self
.
inv_std
=
1.0
/
self
.
std
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
# init model
# init model
self
.
model
=
(
self
.
model
=
(
_video_vae
(
_video_vae
(
...
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
...
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
.
to
(
device
)
.
to
(
device
)
)
)
def
encode
(
self
,
videos
):
def to_cpu(self):
    """Offload the VAE submodules and normalization tensors to host memory.

    Rebuilds ``self.scale`` afterwards so it references the relocated
    ``mean`` / ``inv_std`` tensors rather than stale GPU copies.
    """
    model = self.model
    model.encoder = model.encoder.to("cpu")
    model.decoder = model.decoder.to("cpu")
    self.model = model.to("cpu")
    self.mean, self.inv_std = self.mean.cpu(), self.inv_std.cpu()
    self.scale = [self.mean, self.inv_std]
# self.scale).float().squeeze(0)
# for u in videos
def to_cuda(self):
    """Bring the VAE submodules and normalization tensors onto the GPU.

    Mirror of :meth:`to_cpu`; ``self.scale`` is rebuilt so it points at
    the device-resident ``mean`` / ``inv_std`` tensors.
    """
    model = self.model
    model.encoder = model.encoder.to("cuda")
    model.decoder = model.decoder.to("cuda")
    self.model = model.to("cuda")
    self.mean, self.inv_std = self.mean.cuda(), self.inv_std.cuda()
    self.scale = [self.mean, self.inv_std]
# print(self.model.encode(videos.unsqueeze(0), self.scale).float().shape)
# exit()
def encode(self, videos, args):
    """Encode a video tensor into VAE latents.

    Args:
        videos: Frame tensor; a batch dimension is added before encoding
            and stripped from the result.
        args: Runner config, attribute-style or a plain mapping. Its
            ``cpu_offload`` flag (default False) controls whether the VAE
            is moved to GPU for the call and back to CPU afterwards.

    Returns:
        Float latent tensor with the batch dimension squeezed away.
    """
    # Resolve the flag once. The previous `hasattr(args, "cpu_offload")`
    # check evaluated the condition twice and silently disabled offload
    # for mapping-style configs, which the rest of this commit reads via
    # config.get("cpu_offload", False).
    if isinstance(args, dict):
        offload = bool(args.get("cpu_offload", False))
    else:
        offload = bool(getattr(args, "cpu_offload", False))
    if offload:
        self.to_cuda()
    out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
    if offload:
        self.to_cpu()
    return out
def decode(self, zs, generator, config):
    """Decode VAE latents into pixel-space video.

    Args:
        zs: Latent tensor; a batch dimension is added before decoding.
        generator: Unused here; kept for interface compatibility with
            the other VAE wrappers' decode signatures.
        config: Runner config, attribute-style or a plain mapping. Its
            ``cpu_offload`` flag (default False) controls the GPU
            round-trip.

    Returns:
        Float tensor clamped to [-1, 1]; moved to host memory when
        offloading is enabled.
    """
    # Defaulting lookup: the previous bare `config.cpu_offload` raised
    # AttributeError when the flag was absent or config was a plain
    # mapping, unlike the config.get("cpu_offload", False) pattern used
    # elsewhere in this file.
    if isinstance(config, dict):
        offload = bool(config.get("cpu_offload", False))
    else:
        offload = bool(getattr(config, "cpu_offload", False))
    if offload:
        self.to_cuda()
    images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
    if offload:
        # Pull the result to host before offloading the model again.
        images = images.cpu().float()
        self.to_cpu()
    return images
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment