Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
aa88b371
Commit
aa88b371
authored
Jul 31, 2025
by
gushiqiao
Browse files
Fix offload bugs and support wan2.2_vae offload
parent
6277a533
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
34 deletions
+61
-34
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+26
-9
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+1
-1
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
+34
-24
No files found.
lightx2v/models/networks/wan/model.py
View file @
aa88b371
...
@@ -37,6 +37,9 @@ class WanModel:
...
@@ -37,6 +37,9 @@ class WanModel:
def
__init__
(
self
,
model_path
,
config
,
device
):
def
__init__
(
self
,
model_path
,
config
,
device
):
self
.
model_path
=
model_path
self
.
model_path
=
model_path
self
.
config
=
config
self
.
config
=
config
self
.
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
self
.
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
self
.
clean_cuda_cache
=
self
.
config
.
get
(
"clean_cuda_cache"
,
False
)
self
.
clean_cuda_cache
=
self
.
config
.
get
(
"clean_cuda_cache"
,
False
)
self
.
dit_quantized
=
self
.
config
.
mm_config
.
get
(
"mm_type"
,
"Default"
)
!=
"Default"
self
.
dit_quantized
=
self
.
config
.
mm_config
.
get
(
"mm_type"
,
"Default"
)
!=
"Default"
...
@@ -202,15 +205,18 @@ class WanModel:
...
@@ -202,15 +205,18 @@ class WanModel:
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
:
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
if
self
.
transformer_infer
.
mask_map
is
None
:
if
self
.
transformer_infer
.
mask_map
is
None
:
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
_
,
c
,
h
,
w
=
self
.
scheduler
.
latents
.
shape
video_token_num
=
c
*
(
h
//
2
)
*
(
w
//
2
)
video_token_num
=
c
*
(
h
//
2
)
*
(
w
//
2
)
self
.
transformer_infer
.
mask_map
=
MaskMap
(
video_token_num
,
c
)
self
.
transformer_infer
.
mask_map
=
MaskMap
(
video_token_num
,
c
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
...
@@ -228,14 +234,17 @@ class WanModel:
...
@@ -228,14 +234,17 @@ class WanModel:
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
if
self
.
clean_cuda_cache
:
del
x
,
embed
,
pre_infer_out
,
noise_pred_uncond
,
grid_sizes
torch
.
cuda
.
empty_cache
()
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
:
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
if
self
.
clean_cuda_cache
:
del
x
,
embed
,
pre_infer_out
,
noise_pred_uncond
,
grid_sizes
torch
.
cuda
.
empty_cache
()
class
Wan22MoeModel
(
WanModel
):
class
Wan22MoeModel
(
WanModel
):
def
_load_ckpt
(
self
,
use_bf16
,
skip_bf16
):
def
_load_ckpt
(
self
,
use_bf16
,
skip_bf16
):
...
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
...
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
if
self
.
cpu_offload
and
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
...
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
...
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
if
self
.
cpu_offload
and
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
lightx2v/models/runners/wan/wan_runner.py
View file @
aa88b371
...
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
...
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
return
vae_encoder_out
return
vae_encoder_out
def
get_vae_encoder_output
(
self
,
img
):
def
get_vae_encoder_output
(
self
,
img
):
z
=
self
.
vae_encoder
.
encode
(
img
)
z
=
self
.
vae_encoder
.
encode
(
img
,
self
.
config
)
return
z
return
z
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
View file @
aa88b371
...
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
...
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
mean
=
torch
.
tensor
(
self
.
mean
=
torch
.
tensor
(
[
[
-
0.2289
,
-
0.2289
,
-
0.0052
,
-
0.0052
,
...
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
...
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
)
)
std
=
torch
.
tensor
(
self
.
std
=
torch
.
tensor
(
[
[
0.4765
,
0.4765
,
1.0364
,
1.0364
,
...
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
...
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
)
)
self
.
scale
=
[
mean
,
1.0
/
std
]
self
.
inv_std
=
1.0
/
self
.
std
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
# init model
# init model
self
.
model
=
(
self
.
model
=
(
_video_vae
(
_video_vae
(
...
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
...
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
.
to
(
device
)
.
to
(
device
)
)
)
def
encode
(
self
,
videos
):
def
to_cpu
(
self
):
# try:
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
"cpu"
)
# if not isinstance(videos, list):
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
"cpu"
)
# raise TypeError("videos should be a list")
self
.
model
=
self
.
model
.
to
(
"cpu"
)
# with amp.autocast(dtype=self.dtype):
self
.
mean
=
self
.
mean
.
cpu
()
# return [
self
.
inv_std
=
self
.
inv_std
.
cpu
()
# self.model.encode(u.unsqueeze(0),
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
# self.scale).float().squeeze(0)
# for u in videos
def
to_cuda
(
self
):
# ]
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
"cuda"
)
# except TypeError as e:
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
"cuda"
)
# logging.info(e)
self
.
model
=
self
.
model
.
to
(
"cuda"
)
# return None
self
.
mean
=
self
.
mean
.
cuda
()
self
.
inv_std
=
self
.
inv_std
.
cuda
()
# print(1111111)
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
# print(self.model.encode(videos.unsqueeze(0), self.scale).float().shape)
# exit()
def
encode
(
self
,
videos
,
args
):
if
hasattr
(
args
,
"cpu_offload"
)
and
args
.
cpu_offload
:
return
self
.
model
.
encode
(
videos
.
unsqueeze
(
0
),
self
.
scale
).
float
().
squeeze
(
0
)
self
.
to_cuda
()
out
=
self
.
model
.
encode
(
videos
.
unsqueeze
(
0
),
self
.
scale
).
float
().
squeeze
(
0
)
if
hasattr
(
args
,
"cpu_offload"
)
and
args
.
cpu_offload
:
self
.
to_cpu
()
return
out
def
decode
(
self
,
zs
,
generator
,
config
):
def
decode
(
self
,
zs
,
generator
,
config
):
return
self
.
model
.
decode
(
zs
.
unsqueeze
(
0
),
self
.
scale
).
float
().
clamp_
(
-
1
,
1
)
if
config
.
cpu_offload
:
self
.
to_cuda
()
images
=
self
.
model
.
decode
(
zs
.
unsqueeze
(
0
),
self
.
scale
).
float
().
clamp_
(
-
1
,
1
)
if
config
.
cpu_offload
:
images
=
images
.
cpu
().
float
()
self
.
to_cpu
()
return
images
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment