xuwx1 / LightX2V · Commit 375a6f77

[Feat] support tae for wan_2_2 (#275)

Authored Sep 01, 2025 by gushiqiao; committed via GitHub on Sep 01, 2025.
Parent: f2e1def0

Showing 3 changed files with 182 additions and 68 deletions:
lightx2v/models/runners/wan/wan_runner.py           +15  -48
lightx2v/models/video_encoders/hf/tae.py            +41  -18
lightx2v/models/video_encoders/hf/wan/vae_tiny.py   +126 -2
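Note: the tiny-VAE ("tae") path is driven entirely by config. A minimal sketch of the relevant keys, as read by `WanRunner.load_vae_decoder` in this diff — the surrounding config-file layout is an assumption, only the key names come from the commit:

```python
# Hypothetical config fragment; key names match the loaders in this commit.
config = {
    "use_tiny_vae": True,            # decode with the TAE instead of the full VAE
    "tiny_vae_path": "taew2_2.pth",  # optional; falls back to self.tiny_vae_name
    "need_scaled": False,            # enable per-channel latent de-normalization
}
```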
lightx2v/models/runners/wan/wan_runner.py
@@ -23,7 +23,7 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
-from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
+from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
@@ -34,6 +34,10 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
+        self.vae_cls = WanVAE
+        self.tiny_vae_cls = WanVAE_tiny
+        self.vae_name = "Wan2.1_VAE.pth"
+        self.tiny_vae_name = "taew2_1.pth"

     def load_transformer(self):
         model = WanModel(
@@ -133,7 +137,7 @@ class WanRunner(DefaultRunner):
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
             "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -143,7 +147,7 @@ class WanRunner(DefaultRunner):
         if self.config.task not in ["i2v", "flf2v", "vace"]:
             return None
         else:
-            return WanVAE(**vae_config)
+            return self.vae_cls(**vae_config)

     def load_vae_decoder(self):
         # offload config
@@ -154,7 +158,7 @@ class WanRunner(DefaultRunner):
             vae_device = torch.device("cuda")
         vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
             "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
@@ -162,10 +166,10 @@ class WanRunner(DefaultRunner):
             "dtype": GET_DTYPE(),
         }
         if self.config.get("use_tiny_vae", False):
-            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
-            vae_decoder = WanVAE_tiny(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
+            tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
+            vae_decoder = self.tiny_vae_cls(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
         else:
-            vae_decoder = WanVAE(**vae_config)
+            vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder

     def load_vae(self):
@@ -430,47 +434,10 @@ class Wan22DenseRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
         self.vae_encoder_need_img_original = True
-
-    def load_vae_decoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-            "dtype": GET_DTYPE(),
-        }
-        vae_decoder = Wan2_2_VAE(**vae_config)
-        return vae_decoder
-
-    def load_vae_encoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-            "dtype": GET_DTYPE(),
-        }
-        if self.config.task not in ["i2v", "flf2v"]:
-            return None
-        else:
-            return Wan2_2_VAE(**vae_config)
-
-    def load_vae(self):
-        vae_encoder = self.load_vae_encoder()
-        vae_decoder = self.load_vae_decoder()
-        return vae_encoder, vae_decoder
+        self.vae_cls = Wan2_2_VAE
+        self.tiny_vae_cls = Wan2_2_VAE_tiny
+        self.vae_name = "Wan2.2_VAE.pth"
+        self.tiny_vae_name = "taew2_2.pth"

     def run_vae_encoder(self, img):
         max_area = self.config.target_height * self.config.target_width
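Note: the net effect in this file is a template-method refactor — the base runner keeps the single copy of the loader logic, and `Wan22DenseRunner` now only re-points the class/checkpoint attributes the loaders dispatch on. A minimal, self-contained sketch of the pattern (the toy class bodies below are illustrative, not the real runner code):

```python
# Toy sketch: the base class owns the loader body; subclasses swap attributes.
class BaseRunner:
    def __init__(self):
        self.vae_cls_name = "WanVAE"
        self.vae_name = "Wan2.1_VAE.pth"

    def load_vae_decoder(self):
        # one shared body; no duplicated offload/config logic per subclass
        return f"{self.vae_cls_name}(vae_pth={self.vae_name!r})"

class DenseRunnerSketch(BaseRunner):
    def __init__(self):
        super().__init__()
        self.vae_cls_name = "Wan2_2_VAE"
        self.vae_name = "Wan2.2_VAE.pth"

print(DenseRunnerSketch().load_vae_decoder())
# -> Wan2_2_VAE(vae_pth='Wan2.2_VAE.pth')
```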
lightx2v/models/video_encoders/hf/tae.py (mode 100644 → 100755)
+import os
 #!/usr/bin/env python3
 """
 Tiny AutoEncoder for Hunyuan Video
 (DNN for encoding / decoding videos to Hunyuan Video's latent space)
 """
 from collections import namedtuple

 import torch
@@ -6,8 +11,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from tqdm.auto import tqdm

+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32,expandable_segments:True"
+
 DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
 TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
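Note: placing this at module import time matters because PyTorch reads `PYTORCH_CUDA_ALLOC_CONF` only once, when the CUDA caching allocator initializes at the first GPU allocation; setting it afterwards has no effect. A sketch of the ordering constraint (the env var options shown are the real allocator knobs; the script itself is illustrative):

```python
# Must run before anything allocates on the GPU. Equivalent to launching with:
#   PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:32,expandable_segments:True" python app.py
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32,expandable_segments:True"

import torch  # imported only after the env var is in place

if torch.cuda.is_available():
    x = torch.zeros(1, device="cuda")  # first allocation; config takes effect here
```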
@@ -149,27 +152,31 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
                 xt = b(xt)
                 # add successor to work queue
                 work_queue.insert(0, TWorkItem(xt, i + 1))
     progress_bar.close()
     x = torch.stack(out, 1)
     return x


 class TAEHV(nn.Module):
-    latent_channels = 16
-    image_channels = 3
-
-    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
+    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), patch_size=1, latent_channels=16):
         """Initialize pretrained TAEHV from the given checkpoint.

         Args:
             checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
             decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
             decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
+            patch_size: input/output pixelshuffle patch-size for this model.
+            latent_channels: number of latent channels (z dim) for this model.
         """
         super().__init__()
+        self.patch_size = patch_size
+        self.latent_channels = latent_channels
+        self.image_channels = 3
         self.is_cogvideox = checkpoint_path is not None and "taecvx" in checkpoint_path
+        if checkpoint_path is not None and "taew2_2" in checkpoint_path:
+            self.patch_size, self.latent_channels = 2, 48
         self.encoder = nn.Sequential(
-            conv(TAEHV.image_channels, 64),
+            conv(self.image_channels * self.patch_size**2, 64),
             nn.ReLU(inplace=True),
             TPool(64, 2),
             conv(64, 64, stride=2, bias=False),
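Note: `patch_size=2` (auto-selected for `taew2_2` checkpoints) means Wan 2.2 frames are pixel-unshuffled before the encoder: each 2×2 pixel patch is folded into channels, so a 3-channel frame becomes 3·2² = 12 channels at half resolution, which is exactly what `conv(self.image_channels * self.patch_size**2, 64)` expects. A quick shape check using only stock PyTorch (the resolution is illustrative):

```python
import torch
import torch.nn.functional as F

frame = torch.randn(1, 3, 480, 832)    # one RGB frame, NCHW
folded = F.pixel_unshuffle(frame, 2)   # patch_size = 2, as for taew2_2
print(folded.shape)                    # torch.Size([1, 12, 240, 416])
assert folded.shape[1] == 3 * 2**2     # matches the first encoder conv's in_channels
```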
@@ -186,13 +193,13 @@ class TAEHV(nn.Module):
             MemBlock(64, 64),
             MemBlock(64, 64),
             MemBlock(64, 64),
-            conv(64, TAEHV.latent_channels),
+            conv(64, self.latent_channels),
         )
         n_f = [256, 128, 64, 64]
         self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
         self.decoder = nn.Sequential(
-            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
+            Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
             MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
@@ -213,7 +220,7 @@ class TAEHV(nn.Module):
             TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
             conv(n_f[2], n_f[3], bias=False), nn.ReLU(inplace=True),
-            conv(n_f[3], TAEHV.image_channels),
+            conv(n_f[3], self.image_channels * self.patch_size**2),
         )
         if checkpoint_path is not None:
             self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))
@@ -243,6 +250,13 @@ class TAEHV(nn.Module):
             if False, frames will be processed sequentially.
         Returns NTCHW latent tensor with ~Gaussian values.
         """
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
+        if x.shape[1] % 4 != 0:
+            # pad at end to multiple of 4
+            n_pad = 4 - x.shape[1] % 4
+            padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+            x = torch.cat([x, padding], 1)
         return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

     def decode_video(self, x, parallel=True, show_progress_bar=True):
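Note: the encoder's temporal pooling wants the frame count to be a multiple of 4, so the new code pads by repeating the last frame along the time axis. A small demonstration of that padding arithmetic (tensor sizes are illustrative):

```python
import torch

x = torch.randn(1, 10, 12, 24, 42)   # NTCHW with T=10, not a multiple of 4
if x.shape[1] % 4 != 0:
    n_pad = 4 - x.shape[1] % 4        # 4 - 10 % 4 = 2 extra frames needed
    padding = x[:, -1:].repeat_interleave(n_pad, dim=1)  # repeat the last frame
    x = torch.cat([x, padding], 1)
print(x.shape[1])                     # 12, now divisible by 4
```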
@@ -255,16 +269,23 @@ class TAEHV(nn.Module):
             if False, frames will be processed sequentially.
         Returns NTCHW RGB tensor with ~[0, 1] values.
         """
         skip_trim = self.is_cogvideox and x.shape[1] % 2 == 0
         x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
         x = x.clamp_(0, 1)
+        if self.patch_size > 1:
+            x = F.pixel_shuffle(x, self.patch_size)
         if skip_trim:
             # skip trimming for cogvideox to make frame counts match.
             # this still doesn't have correct temporal alignment for certain frame counts
             # (cogvideox seems to pad at the start?), but for multiple-of-4 it's fine.
             return x
         return x[:, self.frames_to_trim:]

     def forward(self, x):
         return self.c(x)


 @torch.no_grad()
 def main():
     """Run TAEHV roundtrip reconstruction on the given video paths."""
-    import os
     import sys

     import cv2  # no highly esteemed deed is commemorated here
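Note: the trim count used above is pure arithmetic on the constructor flags — `frames_to_trim = 2 ** sum(decoder_time_upscale) - 1` — so with both temporal upscales enabled the decoder drops 2² − 1 = 3 leading warm-up frames before the output aligns with the input. The values for each setting:

```python
# frames_to_trim per decoder_time_upscale setting (sum() counts the True flags)
for time_upscale in [(True, True), (True, False), (False, False)]:
    print(time_upscale, "->", 2 ** sum(time_upscale) - 1)
# (True, True)  -> 3
# (True, False) -> 1
# (False, False) -> 0
```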
@@ -300,8 +321,10 @@ def main():
     dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
     dtype = torch.float16
-    print("Using device", dev, "and dtype", dtype)
-    taehv = TAEHV().to(dev, dtype)
+    checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
+    checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
+    print(f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
+    taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
     for video_path in sys.argv[1:]:
         print(f"Processing {video_path}...")
         video_in = VideoTensorReader(video_path)
@@ -322,7 +345,7 @@ def main():
         print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
         vid_dec = taehv.decode_video(vid_enc, parallel=False)
         print(f"  Decoded {video_path} -> {vid_dec.shape}")
-        video_out_path = video_path + ".reconstructed_by_taehv.mp4"
+        video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
         video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
         for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
             video_out.write(frame)
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
@@ -18,10 +18,8 @@ class WanVAE_tiny(nn.Module):
         self.device = torch.device("cuda")
         self.taehv = TAEHV(vae_pth).to(self.dtype)
-        self.temperal_downsample = [True, True, False]
-        self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
         self.need_scaled = need_scaled
         # temp
         if self.need_scaled:
             self.latents_mean = [
                 -0.7571,
@@ -75,3 +73,129 @@ class WanVAE_tiny(nn.Module):
         # low-memory, set parallel=True for faster + higher memory
         return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
+
+
+class Wan2_2_VAE_tiny(nn.Module):
+    def __init__(self, vae_pth="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
+        super().__init__()
+        self.dtype = dtype
+        self.device = torch.device("cuda")
+        self.taehv = TAEHV(vae_pth).to(self.dtype)
+        self.need_scaled = need_scaled
+        if self.need_scaled:
+            self.latents_mean = [
+                -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
+                -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
+                -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
+                -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
+                -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
+                0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
+            ]
+            self.latents_std = [
+                0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
+                0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
+                0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
+                0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
+                0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
+                0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
+            ]
+        self.z_dim = 48
+
+    @peak_memory_decorator
+    @torch.no_grad()
+    def decode(self, latents):
+        latents = latents.unsqueeze(0)
+        if self.need_scaled:
+            latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents = latents / latents_std + latents_mean
+        # low-memory, set parallel=True for faster + higher memory
+        return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
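Note: one subtlety in `decode` — `latents_std` holds the reciprocal (`1.0 / std`) and is then divided by, so the net operation is `latents * std + mean`, mapping the roughly unit-Gaussian latents back to the raw Wan 2.2 latent distribution before the TAE decodes them. A quick check of that algebra, using the first three per-channel statistics from this file:

```python
import torch

mean = torch.tensor([-0.2289, -0.0052, -0.1323])  # first three channel means
std = torch.tensor([0.4765, 1.0364, 0.4514])      # first three channel stds
z = torch.randn(3)

inv_std = 1.0 / std                                # what decode() stores in latents_std
assert torch.allclose(z / inv_std + mean, z * std + mean)  # same de-normalization
```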