Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
8d32295d
Commit
8d32295d
authored
Aug 09, 2025
by
wangshankun
Browse files
Merge branch 'main' of
https://github.com/ModelTC/LightX2V
into main
parents
daa06243
1994ffb1
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
113 additions
and
45 deletions
+113
-45
configs/offload/block/readme.md
configs/offload/block/readme.md
+0
-1
configs/wan22/wan_ti2v_i2v_4090.json
configs/wan22/wan_ti2v_i2v_4090.json
+21
-0
configs/wan22/wan_ti2v_t2v_4090.json
configs/wan22/wan_ti2v_t2v_4090.json
+20
-0
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+42
-10
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+5
-3
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
+22
-28
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+2
-2
scripts/wan/run_wan_i2v_lazy_load.sh
scripts/wan/run_wan_i2v_lazy_load.sh
+1
-1
No files found.
configs/offload/block/readme.md
deleted
100755 → 0
View file @
daa06243
## TODO
configs/wan22/wan_ti2v_i2v_4090.json
0 → 100755
View file @
8d32295d
{
"infer_steps"
:
50
,
"target_video_length"
:
121
,
"text_len"
:
512
,
"target_height"
:
704
,
"target_width"
:
1280
,
"num_channels_latents"
:
48
,
"vae_stride"
:
[
4
,
16
,
16
],
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
5.0
,
"sample_shift"
:
5.0
,
"enable_cfg"
:
true
,
"fps"
:
24
,
"use_image_encoder"
:
false
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"model"
,
"vae_offload_cache"
:
true
}
configs/wan22/wan_ti2v_t2v_4090.json
0 → 100755
View file @
8d32295d
{
"infer_steps"
:
50
,
"target_video_length"
:
121
,
"text_len"
:
512
,
"target_height"
:
704
,
"target_width"
:
1280
,
"num_channels_latents"
:
48
,
"vae_stride"
:
[
4
,
16
,
16
],
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
5.0
,
"sample_shift"
:
5.0
,
"enable_cfg"
:
true
,
"fps"
:
24
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"model"
,
"vae_offload_cache"
:
true
}
lightx2v/models/runners/wan/wan_runner.py
View file @
8d32295d
...
...
@@ -68,13 +68,13 @@ class WanRunner(DefaultRunner):
assert
clip_quant_scheme
is
not
None
tmp_clip_quant_scheme
=
clip_quant_scheme
.
split
(
"-"
)[
0
]
clip_model_name
=
f
"clip-
{
tmp_clip_quant_scheme
}
.pth"
clip_quantized_ckpt
=
find_torch_model_path
(
self
.
config
,
"clip_quantized_ckpt"
,
clip_model_name
,
tmp_clip_quant_scheme
)
clip_quantized_ckpt
=
find_torch_model_path
(
self
.
config
,
"clip_quantized_ckpt"
,
clip_model_name
)
clip_original_ckpt
=
None
else
:
clip_quantized_ckpt
=
None
clip_quant_scheme
=
None
clip_model_name
=
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
clip_original_ckpt
=
find_torch_model_path
(
self
.
config
,
"clip_original_ckpt"
,
clip_model_name
,
"original"
)
clip_original_ckpt
=
find_torch_model_path
(
self
.
config
,
"clip_original_ckpt"
,
clip_model_name
)
image_encoder
=
CLIPModel
(
dtype
=
torch
.
float16
,
...
...
@@ -90,7 +90,7 @@ class WanRunner(DefaultRunner):
def
load_text_encoder
(
self
):
# offload config
t5_offload
=
self
.
config
.
get
(
"t5_cpu_offload"
,
False
)
t5_offload
=
self
.
config
.
get
(
"t5_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
)
)
if
t5_offload
:
t5_device
=
torch
.
device
(
"cpu"
)
else
:
...
...
@@ -103,14 +103,14 @@ class WanRunner(DefaultRunner):
assert
t5_quant_scheme
is
not
None
tmp_t5_quant_scheme
=
t5_quant_scheme
.
split
(
"-"
)[
0
]
t5_model_name
=
f
"models_t5_umt5-xxl-enc-
{
tmp_t5_quant_scheme
}
.pth"
t5_quantized_ckpt
=
find_torch_model_path
(
self
.
config
,
"t5_quantized_ckpt"
,
t5_model_name
,
tmp_t5_quant_scheme
)
t5_quantized_ckpt
=
find_torch_model_path
(
self
.
config
,
"t5_quantized_ckpt"
,
t5_model_name
)
t5_original_ckpt
=
None
tokenizer_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
t5_quantized_ckpt
),
"google/umt5-xxl"
)
else
:
t5_quant_scheme
=
None
t5_quantized_ckpt
=
None
t5_model_name
=
"models_t5_umt5-xxl-enc-bf16.pth"
t5_original_ckpt
=
find_torch_model_path
(
self
.
config
,
"t5_original_ckpt"
,
t5_model_name
,
"original"
)
t5_original_ckpt
=
find_torch_model_path
(
self
.
config
,
"t5_original_ckpt"
,
t5_model_name
)
tokenizer_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
t5_original_ckpt
),
"google/umt5-xxl"
)
text_encoder
=
T5EncoderModel
(
...
...
@@ -121,7 +121,7 @@ class WanRunner(DefaultRunner):
tokenizer_path
=
tokenizer_path
,
shard_fn
=
None
,
cpu_offload
=
t5_offload
,
offload_granularity
=
self
.
config
.
get
(
"t5_offload_granularity"
,
"model"
),
offload_granularity
=
self
.
config
.
get
(
"t5_offload_granularity"
,
"model"
),
# support ['model', 'block']
t5_quantized
=
t5_quantized
,
t5_quantized_ckpt
=
t5_quantized_ckpt
,
quant_scheme
=
t5_quant_scheme
,
...
...
@@ -131,12 +131,20 @@ class WanRunner(DefaultRunner):
return
text_encoders
def
load_vae_encoder
(
self
):
# offload config
vae_offload
=
self
.
config
.
get
(
"vae_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
))
if
vae_offload
:
vae_device
=
torch
.
device
(
"cpu"
)
else
:
vae_device
=
torch
.
device
(
"cuda"
)
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
"Wan2.1_VAE.pth"
),
"device"
:
self
.
init
_device
,
"device"
:
vae
_device
,
"parallel"
:
self
.
config
.
parallel
and
self
.
config
.
parallel
.
get
(
"vae_p_size"
,
False
)
and
self
.
config
.
parallel
.
vae_p_size
>
1
,
"use_tiling"
:
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
"seq_p_group"
:
self
.
seq_p_group
,
"cpu_offload"
:
vae_offload
,
}
if
self
.
config
.
task
!=
"i2v"
:
return
None
...
...
@@ -144,11 +152,19 @@ class WanRunner(DefaultRunner):
return
WanVAE
(
**
vae_config
)
def
load_vae_decoder
(
self
):
# offload config
vae_offload
=
self
.
config
.
get
(
"vae_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
))
if
vae_offload
:
vae_device
=
torch
.
device
(
"cpu"
)
else
:
vae_device
=
torch
.
device
(
"cuda"
)
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
"Wan2.1_VAE.pth"
),
"device"
:
self
.
init
_device
,
"device"
:
vae
_device
,
"parallel"
:
self
.
config
.
parallel
and
self
.
config
.
parallel
.
get
(
"vae_p_size"
,
False
)
and
self
.
config
.
parallel
.
vae_p_size
>
1
,
"use_tiling"
:
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
"cpu_offload"
:
vae_offload
,
}
if
self
.
config
.
get
(
"use_tiny_vae"
,
False
):
tiny_vae_path
=
find_torch_model_path
(
self
.
config
,
"tiny_vae_path"
,
"taew2_1.pth"
)
...
...
@@ -398,17 +414,33 @@ class Wan22DenseRunner(WanRunner):
super
().
__init__
(
config
)
def
load_vae_decoder
(
self
):
# offload config
vae_offload
=
self
.
config
.
get
(
"vae_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
))
if
vae_offload
:
vae_device
=
torch
.
device
(
"cpu"
)
else
:
vae_device
=
torch
.
device
(
"cuda"
)
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
"Wan2.2_VAE.pth"
),
"device"
:
self
.
init_device
,
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"offload_cache"
:
self
.
config
.
get
(
"vae_offload_cache"
,
False
),
}
vae_decoder
=
Wan2_2_VAE
(
**
vae_config
)
return
vae_decoder
def
load_vae_encoder
(
self
):
# offload config
vae_offload
=
self
.
config
.
get
(
"vae_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
))
if
vae_offload
:
vae_device
=
torch
.
device
(
"cpu"
)
else
:
vae_device
=
torch
.
device
(
"cuda"
)
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
"Wan2.2_VAE.pth"
),
"device"
:
self
.
init_device
,
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"offload_cache"
:
self
.
config
.
get
(
"vae_offload_cache"
,
False
),
}
if
self
.
config
.
task
!=
"i2v"
:
return
None
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
8d32295d
...
...
@@ -797,11 +797,13 @@ class WanVAE:
parallel
=
False
,
use_tiling
=
False
,
seq_p_group
=
None
,
cpu_offload
=
False
,
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
parallel
=
parallel
self
.
use_tiling
=
use_tiling
self
.
cpu_offload
=
cpu_offload
mean
=
[
-
0.7571
,
...
...
@@ -938,8 +940,8 @@ class WanVAE:
return
images
def
decode
(
self
,
zs
,
generator
,
config
):
if
config
.
cpu_offload
:
def
decode
(
self
,
zs
,
**
args
):
if
self
.
cpu_offload
:
self
.
to_cuda
()
if
self
.
parallel
:
...
...
@@ -960,7 +962,7 @@ class WanVAE:
else
:
images
=
self
.
model
.
decode
(
zs
.
unsqueeze
(
0
),
self
.
scale
).
float
().
clamp_
(
-
1
,
1
)
if
config
.
cpu_offload
:
if
self
.
cpu_offload
:
images
=
images
.
cpu
().
float
()
self
.
to_cpu
()
...
...
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
View file @
8d32295d
...
...
@@ -619,7 +619,7 @@ class Decoder3d(nn.Module):
CausalConv3d
(
out_dim
,
12
,
3
,
padding
=
1
),
)
def
forward
(
self
,
x
,
feat_cache
=
None
,
feat_idx
=
[
0
],
first_chunk
=
False
):
def
forward
(
self
,
x
,
feat_cache
=
None
,
feat_idx
=
[
0
],
first_chunk
=
False
,
offload_cache
=
False
):
if
feat_cache
is
not
None
:
idx
=
feat_idx
[
0
]
cache_x
=
x
[:,
:,
-
CACHE_T
:,
:,
:].
clone
()
...
...
@@ -639,14 +639,24 @@ class Decoder3d(nn.Module):
for
layer
in
self
.
middle
:
if
isinstance
(
layer
,
ResidualBlock
)
and
feat_cache
is
not
None
:
idx
=
feat_idx
[
0
]
x
=
layer
(
x
,
feat_cache
,
feat_idx
)
if
offload_cache
:
for
_idx
in
range
(
idx
,
feat_idx
[
0
]):
if
isinstance
(
feat_cache
[
_idx
],
torch
.
Tensor
):
feat_cache
[
_idx
]
=
feat_cache
[
_idx
].
cpu
()
else
:
x
=
layer
(
x
)
## upsamples
for
layer
in
self
.
upsamples
:
if
feat_cache
is
not
None
:
idx
=
feat_idx
[
0
]
x
=
layer
(
x
,
feat_cache
,
feat_idx
,
first_chunk
)
if
offload_cache
:
for
_idx
in
range
(
idx
,
feat_idx
[
0
]):
if
isinstance
(
feat_cache
[
_idx
],
torch
.
Tensor
):
feat_cache
[
_idx
]
=
feat_cache
[
_idx
].
cpu
()
else
:
x
=
layer
(
x
)
...
...
@@ -664,7 +674,7 @@ class Decoder3d(nn.Module):
dim
=
2
,
)
x
=
layer
(
x
,
feat_cache
[
idx
])
feat_cache
[
idx
]
=
cache_x
feat_cache
[
idx
]
=
cache_x
.
cpu
()
if
offload_cache
else
cache_x
feat_idx
[
0
]
+=
1
else
:
x
=
layer
(
x
)
...
...
@@ -755,7 +765,7 @@ class WanVAE_(nn.Module):
self
.
clear_cache
()
return
mu
def
decode
(
self
,
z
,
scale
):
def
decode
(
self
,
z
,
scale
,
offload_cache
=
False
):
self
.
clear_cache
()
if
isinstance
(
scale
[
0
],
torch
.
Tensor
):
z
=
z
/
scale
[
1
].
view
(
1
,
self
.
z_dim
,
1
,
1
,
1
)
+
scale
[
0
].
view
(
1
,
self
.
z_dim
,
1
,
1
,
1
)
...
...
@@ -766,18 +776,9 @@ class WanVAE_(nn.Module):
for
i
in
range
(
iter_
):
self
.
_conv_idx
=
[
0
]
if
i
==
0
:
out
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
,
first_chunk
=
True
,
)
out
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
,
first_chunk
=
True
,
offload_cache
=
offload_cache
)
else
:
out_
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
,
)
out_
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
,
offload_cache
=
offload_cache
)
out
=
torch
.
cat
([
out
,
out_
],
2
)
out
=
unpatchify
(
out
,
patch_size
=
2
)
self
.
clear_cache
()
...
...
@@ -830,18 +831,11 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
class
Wan2_2_VAE
:
def
__init__
(
self
,
z_dim
=
48
,
c_dim
=
160
,
vae_pth
=
None
,
dim_mult
=
[
1
,
2
,
4
,
4
],
temperal_downsample
=
[
False
,
True
,
True
],
dtype
=
torch
.
float
,
device
=
"cuda"
,
):
def
__init__
(
self
,
z_dim
=
48
,
c_dim
=
160
,
vae_pth
=
None
,
dim_mult
=
[
1
,
2
,
4
,
4
],
temperal_downsample
=
[
False
,
True
,
True
],
dtype
=
torch
.
float
,
device
=
"cuda"
,
cpu_offload
=
False
,
offload_cache
=
False
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
cpu_offload
=
cpu_offload
self
.
offload_cache
=
offload_cache
self
.
mean
=
torch
.
tensor
(
[
...
...
@@ -991,11 +985,11 @@ class Wan2_2_VAE:
self
.
to_cpu
()
return
out
def
decode
(
self
,
zs
,
generator
,
config
):
if
config
.
cpu_offload
:
def
decode
(
self
,
zs
,
**
args
):
if
self
.
cpu_offload
:
self
.
to_cuda
()
images
=
self
.
model
.
decode
(
zs
.
unsqueeze
(
0
),
self
.
scale
).
float
().
clamp_
(
-
1
,
1
)
if
config
.
cpu_offload
:
images
=
self
.
model
.
decode
(
zs
.
unsqueeze
(
0
),
self
.
scale
,
offload_cache
=
self
.
offload_cache
if
self
.
cpu_offload
else
False
).
float
().
clamp_
(
-
1
,
1
)
if
self
.
cpu_offload
:
images
=
images
.
cpu
().
float
()
self
.
to_cpu
()
return
images
lightx2v/utils/utils.py
View file @
8d32295d
...
...
@@ -258,7 +258,7 @@ def save_to_video(
raise
ValueError
(
f
"Unknown save method:
{
method
}
"
)
def
find_torch_model_path
(
config
,
ckpt_config_key
=
None
,
filename
=
None
,
subdir
=
[
"original"
,
"fp8"
,
"int8"
]):
def
find_torch_model_path
(
config
,
ckpt_config_key
=
None
,
filename
=
None
,
subdir
=
[
"original"
,
"fp8"
,
"int8"
,
"distill_models"
,
"distill_fp8"
,
"distill_int8"
]):
if
ckpt_config_key
and
config
.
get
(
ckpt_config_key
,
None
)
is
not
None
:
return
config
.
get
(
ckpt_config_key
)
...
...
@@ -277,7 +277,7 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
raise
FileNotFoundError
(
f
"PyTorch model file '
{
filename
}
' not found.
\n
Please download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file."
)
def
find_hf_model_path
(
config
,
model_path
,
ckpt_config_key
=
None
,
subdir
=
[
"original"
,
"fp8"
,
"int8"
]):
def
find_hf_model_path
(
config
,
model_path
,
ckpt_config_key
=
None
,
subdir
=
[
"original"
,
"fp8"
,
"int8"
,
"distill_models"
,
"distill_fp8"
,
"distill_int8"
]):
if
ckpt_config_key
and
config
.
get
(
ckpt_config_key
,
None
)
is
not
None
:
return
config
.
get
(
ckpt_config_key
)
...
...
scripts/wan/run_wan_i2v_lazy_load.sh
View file @
8d32295d
...
...
@@ -5,7 +5,7 @@ lightx2v_path=
model_path
=
export
CUDA_VISIBLE_DEVICES
=
0
export
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments:True
# set environment variables
source
${
lightx2v_path
}
/scripts/base/base.sh
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment