xuwx1 / LightX2V · Commits · a395cc0a

Commit a395cc0a, authored Jul 31, 2025 by gushiqiao; committed by GitHub, Jul 31, 2025.

Fix offload bugs and support wan2.2_vae offload

Parents: b723fc89, aa88b371
Changes: 3 files, with 61 additions and 34 deletions (+61 / -34).

- lightx2v/models/networks/wan/model.py (+26 / -9)
- lightx2v/models/runners/wan/wan_runner.py (+1 / -1)
- lightx2v/models/video_encoders/hf/wan/vae_2_2.py (+34 / -24)
lightx2v/models/networks/wan/model.py

```diff
@@ -37,6 +37,9 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.cpu_offload = self.config.get("cpu_offload", False)
+        self.offload_granularity = self.config.get("offload_granularity", "block")
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
```
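The three attributes read here come from the runtime config, with defaults that leave offloading disabled. A minimal sketch of a config fragment that exercises the new path (only the three key names come from the diff; everything else about the config's structure is an assumption):

```python
# Hypothetical config fragment: only these three keys are read by the
# __init__ shown above; their defaults are False / "block" / False.
config = {
    "cpu_offload": True,             # move weights between CPU and GPU around inference
    "offload_granularity": "block",  # "model" offloads the whole DiT at step boundaries
    "clean_cuda_cache": False,       # optionally free intermediates and empty the CUDA cache
}
```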
```diff
@@ -202,15 +205,18 @@ class WanModel:
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
+
         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape
             video_token_num = c * (h // 2) * (w // 2)
             self.transformer_infer.mask_map = MaskMap(video_token_num, c)
-        if self.config.get("cpu_offload", False):
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -228,14 +234,17 @@ class WanModel:
         self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-        if self.config.get("cpu_offload", False):
-            if self.clean_cuda_cache:
-                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
-                torch.cuda.empty_cache()
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
+
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+            torch.cuda.empty_cache()
 
 
 class Wan22MoeModel(WanModel):
     def _load_ckpt(self, use_bf16, skip_bf16):
```
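With `offload_granularity == "model"`, the updated `infer` moves the whole model to the GPU only at the first scheduler step and back to the CPU after the last one; any other granularity keeps the pre/post weights on the GPU only for the duration of a single call. A minimal, self-contained sketch of the step-boundary variant (the `OffloadedModule` wrapper is a toy stand-in, not LightX2V's class):

```python
import torch
import torch.nn as nn


class OffloadedModule:
    """Toy illustration of step-boundary ("model" granularity) offload."""

    def __init__(self, module: nn.Module, infer_steps: int):
        self.module = module
        self.infer_steps = infer_steps

    @torch.no_grad()
    def infer(self, x: torch.Tensor, step_index: int) -> torch.Tensor:
        if step_index == 0 and torch.cuda.is_available():
            self.module.to("cuda")              # load weights once, at the first step
        device = next(self.module.parameters()).device
        out = self.module(x.to(device))
        if step_index == self.infer_steps - 1:
            self.module.to("cpu")               # release GPU memory after the last step
            torch.cuda.empty_cache()
        return out


runner = OffloadedModule(nn.Linear(8, 8), infer_steps=4)
ys = [runner.infer(torch.randn(2, 8), step_index=i) for i in range(4)]
```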
```diff
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cuda()
+            self.post_weight.to_cuda()
+
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
         self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
+
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cpu()
+            self.post_weight.to_cpu()
```
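`Wan22MoeModel.infer`, like the non-"model" branch in `WanModel.infer`, brackets the forward pass: `pre_weight`/`post_weight` go to CUDA before pre/post inference and back to CPU afterwards. That bracket could be written as a small context manager; the sketch below assumes objects exposing `to_cuda()`/`to_cpu()` like the weight wrappers in the diff, and the helper name is hypothetical:

```python
from contextlib import contextmanager


@contextmanager
def weights_on_gpu(*weights, enabled=True):
    """Hypothetical helper: keep weight wrappers on CUDA for the duration of one call."""
    if enabled:
        for w in weights:
            w.to_cuda()
    try:
        yield
    finally:
        if enabled:
            for w in weights:
                w.to_cpu()


# Usage, mirroring the bracket in Wan22MoeModel.infer:
# with weights_on_gpu(self.pre_weight, self.post_weight,
#                     enabled=self.cpu_offload and self.offload_granularity != "model"):
#     embed, grid_sizes, pre_infer_out = self.pre_infer.infer(...)
```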
lightx2v/models/runners/wan/wan_runner.py

```diff
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out
 
     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img)
+        z = self.vae_encoder.encode(img, self.config)
         return z
```
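The runner now forwards its config so the VAE can decide for itself whether to offload around the encode call. The guard used by the updated VAE (shown in the vae_2_2.py hunk below) only checks for a truthy `cpu_offload` attribute, so any attribute-style config object works; a small illustration with a stand-in namespace object (the variable names here are hypothetical):

```python
from types import SimpleNamespace


def wants_offload(args) -> bool:
    # Same guard the updated Wan2_2_VAE.encode uses before calling to_cuda()/to_cpu().
    return hasattr(args, "cpu_offload") and bool(args.cpu_offload)


cfg_offload = SimpleNamespace(cpu_offload=True)   # stand-in for a config with offload enabled
cfg_default = SimpleNamespace()                   # no attribute at all -> offload stays off

print(wants_offload(cfg_offload))  # True
print(wants_offload(cfg_default))  # False
```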
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

```diff
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
         self.dtype = dtype
         self.device = device
 
-        mean = torch.tensor(
+        self.mean = torch.tensor(
             [
                 -0.2289,
                 -0.0052,
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        std = torch.tensor(
+        self.std = torch.tensor(
             [
                 0.4765,
                 1.0364,
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        self.scale = [mean, 1.0 / std]
+        self.inv_std = 1.0 / self.std
+        self.scale = [self.mean, self.inv_std]
 
         # init model
         self.model = (
             _video_vae(
```
```diff
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
             .to(device)
         )
 
-    def encode(self, videos):
-        # try:
-        #     if not isinstance(videos, list):
-        #         raise TypeError("videos should be a list")
-        #     with amp.autocast(dtype=self.dtype):
-        #         return [
-        #             self.model.encode(u.unsqueeze(0),
-        #                               self.scale).float().squeeze(0)
-        #             for u in videos
-        #         ]
-        # except TypeError as e:
-        #     logging.info(e)
-        #     return None
-        # print(1111111)
-        # print(self.model.encode(videos.unsqueeze(0), self.scale).float().shape)
-        # exit()
-        return self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+    def to_cpu(self):
+        self.model.encoder = self.model.encoder.to("cpu")
+        self.model.decoder = self.model.decoder.to("cpu")
+        self.model = self.model.to("cpu")
+        self.mean = self.mean.cpu()
+        self.inv_std = self.inv_std.cpu()
+        self.scale = [self.mean, self.inv_std]
+
+    def to_cuda(self):
+        self.model.encoder = self.model.encoder.to("cuda")
+        self.model.decoder = self.model.decoder.to("cuda")
+        self.model = self.model.to("cuda")
+        self.mean = self.mean.cuda()
+        self.inv_std = self.inv_std.cuda()
+        self.scale = [self.mean, self.inv_std]
+
+    def encode(self, videos, args):
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cuda()
+        out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cpu()
+        return out
 
     def decode(self, zs, generator, config):
-        return self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            self.to_cuda()
+        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            images = images.cpu().float()
+            self.to_cpu()
+        return images
```
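Both new methods follow the same wrap: move the VAE to the GPU, run the call, copy the result back to the host, then return the weights to the CPU so the model does not occupy GPU memory between calls. A minimal self-contained sketch of that pattern on a plain `nn.Module` (the stub module, shapes, and function name are stand-ins, not the project's VAE):

```python
import torch
import torch.nn as nn


def run_offloaded(module: nn.Module, x: torch.Tensor, cpu_offload: bool) -> torch.Tensor:
    """Toy version of the to_cuda() / compute / to_cpu() wrap used above."""
    use_cuda = cpu_offload and torch.cuda.is_available()
    if use_cuda:
        module.to("cuda")
        x = x.to("cuda")
    with torch.no_grad():
        out = module(x)
    if use_cuda:
        out = out.cpu().float()   # bring the result back to the host first
        module.to("cpu")          # then release the weights from GPU memory
    return out


vae_stub = nn.Conv3d(3, 8, kernel_size=3, padding=1)   # stand-in for the real VAE
frames = torch.randn(1, 3, 4, 16, 16)                  # toy (B, C, T, H, W) input
latent = run_offloaded(vae_stub, frames, cpu_offload=True)
```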