xuwx1 / LightX2V · Commit 1994ffb1

Support offload cache for wan2.2_vae

Authored Aug 08, 2025 by gushiqiao; committed via GitHub on Aug 08, 2025
Parents: 64948a2e, 83a73049

Showing 8 changed files with 113 additions and 45 deletions (+113, -45)
configs/offload/block/readme.md                   +0   -1
configs/wan22/wan_ti2v_i2v_4090.json              +21  -0
configs/wan22/wan_ti2v_t2v_4090.json              +20  -0
lightx2v/models/runners/wan/wan_runner.py         +42  -10
lightx2v/models/video_encoders/hf/wan/vae.py      +5   -3
lightx2v/models/video_encoders/hf/wan/vae_2_2.py  +22  -28
lightx2v/utils/utils.py                           +2   -2
scripts/wan/run_wan_i2v_lazy_load.sh              +1   -1
configs/offload/block/readme.md (deleted, mode 100755 → 0)

-## TODO
configs/wan22/wan_ti2v_i2v_4090.json (new file, mode 100755)

{
    "infer_steps": 50,
    "target_video_length": 121,
    "text_len": 512,
    "target_height": 704,
    "target_width": 1280,
    "num_channels_latents": 48,
    "vae_stride": [4, 16, 16],
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5.0,
    "sample_shift": 5.0,
    "enable_cfg": true,
    "fps": 24,
    "use_image_encoder": false,
    "cpu_offload": true,
    "offload_granularity": "model",
    "vae_offload_cache": true
}
configs/wan22/wan_ti2v_t2v_4090.json (new file, mode 100755)

{
    "infer_steps": 50,
    "target_video_length": 121,
    "text_len": 512,
    "target_height": 704,
    "target_width": 1280,
    "num_channels_latents": 48,
    "vae_stride": [4, 16, 16],
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5.0,
    "sample_shift": 5.0,
    "enable_cfg": true,
    "fps": 24,
    "cpu_offload": true,
    "offload_granularity": "model",
    "vae_offload_cache": true
}
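Both configs enable model-level CPU offload plus the new vae_offload_cache flag; the i2v variant additionally sets "use_image_encoder": false. A minimal sketch of how these keys resolve at load time, mirroring the config.get(...) fallback pattern in the wan_runner.py hunks below (the helper name is our illustration, not part of the commit):

import json
import torch

def resolve_vae_offload(config: dict):
    # A module-specific key wins; otherwise fall back to the global flag.
    vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
    # Caching offloaded decoder features is controlled separately.
    offload_cache = config.get("vae_offload_cache", False)
    device = torch.device("cpu") if vae_offload else torch.device("cuda")
    return vae_offload, offload_cache, device

with open("configs/wan22/wan_ti2v_t2v_4090.json") as f:
    cfg = json.load(f)
print(resolve_vae_offload(cfg))  # -> (True, True, device(type='cpu'))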
lightx2v/models/runners/wan/wan_runner.py

@@ -68,13 +68,13 @@ class WanRunner(DefaultRunner):
             assert clip_quant_scheme is not None
             tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
             clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
-            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
+            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
             clip_original_ckpt = None
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
             clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
-            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")
+            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)
         image_encoder = CLIPModel(
             dtype=torch.float16,

@@ -90,7 +90,7 @@ class WanRunner(DefaultRunner):
     def load_text_encoder(self):
         # offload config
-        t5_offload = self.config.get("t5_cpu_offload", False)
+        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
         if t5_offload:
             t5_device = torch.device("cpu")
         else:

@@ -103,14 +103,14 @@ class WanRunner(DefaultRunner):
             assert t5_quant_scheme is not None
             tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
-            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
+            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
             t5_original_ckpt = None
             tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
             t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
-            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
+            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
             tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")
         text_encoder = T5EncoderModel(

@@ -121,7 +121,7 @@ class WanRunner(DefaultRunner):
             tokenizer_path=tokenizer_path,
             shard_fn=None,
             cpu_offload=t5_offload,
-            offload_granularity=self.config.get("t5_offload_granularity", "model"),
+            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,

@@ -131,12 +131,20 @@ class WanRunner(DefaultRunner):
         return text_encoders
 
     def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "seq_p_group": self.seq_p_group,
+            "cpu_offload": vae_offload,
         }
         if self.config.task != "i2v":
             return None

@@ -144,11 +152,19 @@ class WanRunner(DefaultRunner):
         return WanVAE(**vae_config)
 
     def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
         }
         if self.config.get("use_tiny_vae", False):
             tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")

@@ -398,17 +414,33 @@ class Wan22DenseRunner(WanRunner):
         super().__init__(config)
 
     def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
         }
         vae_decoder = Wan2_2_VAE(**vae_config)
         return vae_decoder
 
     def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
         }
         if self.config.task != "i2v":
             return None
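The same "offload config" preamble now appears in four loaders: the VAE is constructed on CPU when offloading is requested and moved to CUDA only around actual use. A hedged sketch of what that repeated block computes, factored into a standalone helper (the helper is our illustration, not code from the commit):

import torch

def vae_offload_settings(config):
    # "vae_cpu_offload" overrides the global "cpu_offload" flag; when
    # offloading, the VAE lives on CPU between decode/encode calls.
    vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload"))
    vae_device = torch.device("cpu") if vae_offload else torch.device("cuda")
    return vae_offload, vae_device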
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -799,11 +799,13 @@ class WanVAE:
         parallel=False,
         use_tiling=False,
         seq_p_group=None,
+        cpu_offload=False,
     ):
         self.dtype = dtype
         self.device = device
         self.parallel = parallel
         self.use_tiling = use_tiling
+        self.cpu_offload = cpu_offload
         mean = [
             -0.7571,

@@ -940,8 +942,8 @@ class WanVAE:
         return images
 
-    def decode(self, zs, generator, config):
-        if config.cpu_offload:
+    def decode(self, zs, **args):
+        if self.cpu_offload:
             self.to_cuda()
         if self.parallel:

@@ -962,7 +964,7 @@ class WanVAE:
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        if config.cpu_offload:
+        if self.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
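Note the interface change: WanVAE.decode now reads cpu_offload from the instance rather than a passed-in config, and **args absorbs only keyword arguments, so a caller still passing generator or config positionally would raise a TypeError. A hypothetical call under the new signature (vae and zs are assumed to exist):

# Offload state lives on the instance; extra kwargs are absorbed by **args.
images = vae.decode(zs, generator=None, config=None)
# A positional call like vae.decode(zs, gen, cfg) would now fail.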
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

@@ -619,7 +619,7 @@ class Decoder3d(nn.Module):
             CausalConv3d(out_dim, 12, 3, padding=1),
         )
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, offload_cache=False):
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()

@@ -639,14 +639,24 @@ class Decoder3d(nn.Module):
         for layer in self.middle:
             if isinstance(layer, ResidualBlock) and feat_cache is not None:
+                idx = feat_idx[0]
                 x = layer(x, feat_cache, feat_idx)
+                if offload_cache:
+                    for _idx in range(idx, feat_idx[0]):
+                        if isinstance(feat_cache[_idx], torch.Tensor):
+                            feat_cache[_idx] = feat_cache[_idx].cpu()
             else:
                 x = layer(x)
 
         ## upsamples
         for layer in self.upsamples:
             if feat_cache is not None:
+                idx = feat_idx[0]
                 x = layer(x, feat_cache, feat_idx, first_chunk)
+                if offload_cache:
+                    for _idx in range(idx, feat_idx[0]):
+                        if isinstance(feat_cache[_idx], torch.Tensor):
+                            feat_cache[_idx] = feat_cache[_idx].cpu()
             else:
                 x = layer(x)
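The pattern in this hunk: each cached layer writes one or more new entries into feat_cache and advances feat_idx[0] as it goes, so snapshotting the index before the call and sweeping range(idx, feat_idx[0]) afterwards touches exactly the entries that layer produced, staging them on CPU between temporal chunks. A minimal self-contained sketch of the same idea (the wrapper function is ours; the names match the diff):

import torch

def run_cached_layer(layer, x, feat_cache, feat_idx, offload_cache=False):
    idx = feat_idx[0]                   # first cache slot this layer will fill
    x = layer(x, feat_cache, feat_idx)  # layer advances feat_idx[0] as it caches
    if offload_cache:
        # Park only the freshly written cache tensors on CPU so the temporal
        # cache does not accumulate in GPU memory between decoded chunks.
        for _idx in range(idx, feat_idx[0]):
            if isinstance(feat_cache[_idx], torch.Tensor):
                feat_cache[_idx] = feat_cache[_idx].cpu()
    return x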
@@ -664,7 +674,7 @@ class Decoder3d(nn.Module):
                         dim=2,
                     )
                 x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
+                feat_cache[idx] = cache_x.cpu() if offload_cache else cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)

@@ -755,7 +765,7 @@ class WanVAE_(nn.Module):
         self.clear_cache()
         return mu
 
-    def decode(self, z, scale):
+    def decode(self, z, scale, offload_cache=False):
         self.clear_cache()
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)

@@ -766,18 +776,9 @@ class WanVAE_(nn.Module):
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(
-                    x[:, :, i : i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
-                    first_chunk=True,
-                )
+                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, offload_cache=offload_cache)
             else:
-                out_ = self.decoder(
-                    x[:, :, i : i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
-                )
+                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, offload_cache=offload_cache)
                 out = torch.cat([out, out_], 2)
         out = unpatchify(out, patch_size=2)
         self.clear_cache()

@@ -830,18 +831,11 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
 class Wan2_2_VAE:
-    def __init__(
-        self,
-        z_dim=48,
-        c_dim=160,
-        vae_pth=None,
-        dim_mult=[1, 2, 4, 4],
-        temperal_downsample=[False, True, True],
-        dtype=torch.float,
-        device="cuda",
-    ):
+    def __init__(self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False):
         self.dtype = dtype
         self.device = device
+        self.cpu_offload = cpu_offload
+        self.offload_cache = offload_cache
         self.mean = torch.tensor(
             [

@@ -991,11 +985,11 @@ class Wan2_2_VAE:
             self.to_cpu()
         return out
 
-    def decode(self, zs, generator, config):
-        if config.cpu_offload:
+    def decode(self, zs, **args):
+        if self.cpu_offload:
             self.to_cuda()
-        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        if config.cpu_offload:
+        images = self.model.decode(zs.unsqueeze(0), self.scale, offload_cache=self.offload_cache if self.cpu_offload else False).float().clamp_(-1, 1)
+        if self.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
         return images
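Putting the pieces together, a hypothetical end-to-end decode with the new flags (the checkpoint path and latent shape are placeholders, not values from the commit):

import torch
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE

vae = Wan2_2_VAE(
    vae_pth="Wan2.2_VAE.pth",  # assumed to be downloaded locally
    device="cpu",              # weights parked on CPU between calls
    cpu_offload=True,          # decode() moves the model to CUDA and back
    offload_cache=True,        # decoder feat_cache tensors are parked on CPU too
)
zs = torch.randn(48, 31, 44, 80)  # (z_dim, t, h, w) latents; shape illustrative
images = vae.decode(zs)           # frames returned on CPU when cpu_offload=True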
lightx2v/utils/utils.py

@@ -258,7 +258,7 @@ def save_to_video(
         raise ValueError(f"Unknown save method: {method}")
 
-def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8"]):
+def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)

@@ -277,7 +277,7 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
         raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
 
-def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
+def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
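Widening the default subdir list means quantized and distilled checkpoints are found without callers passing an explicit subdirectory, which is why the wan_runner.py hunks above drop their fourth argument. The function body is elided in this diff; a hedged sketch of the search it presumably performs over those subdirectories:

import os

def _search_subdirs(model_path, filename, subdir):
    # Try each candidate subdirectory under the model root in order,
    # returning the first path that actually exists on disk.
    for d in subdir:
        candidate = os.path.join(model_path, d, filename)
        if os.path.exists(candidate):
            return candidate
    return None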
scripts/wan/run_wan_i2v_lazy_load.sh

@@ -5,7 +5,7 @@ lightx2v_path=
 model_path=
 
 export CUDA_VISIBLE_DEVICES=0
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh