Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
89ce2aa6
Commit
89ce2aa6
authored
Sep 01, 2025
by
gushiqiao
Committed by
GitHub
Sep 01, 2025
Browse files
[feat] support lightvae (#272)
parent
cb359e19
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
6 deletions
+51
-6
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+1
-4
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
+50
-2
No files found.
lightx2v/models/runners/wan/wan_runner.py
View file @
89ce2aa6
...
@@ -163,10 +163,7 @@ class WanRunner(DefaultRunner):
...
@@ -163,10 +163,7 @@ class WanRunner(DefaultRunner):
}
}
if
self
.
config
.
get
(
"use_tiny_vae"
,
False
):
if
self
.
config
.
get
(
"use_tiny_vae"
,
False
):
tiny_vae_path
=
find_torch_model_path
(
self
.
config
,
"tiny_vae_path"
,
"taew2_1.pth"
)
tiny_vae_path
=
find_torch_model_path
(
self
.
config
,
"tiny_vae_path"
,
"taew2_1.pth"
)
vae_decoder
=
WanVAE_tiny
(
vae_decoder
=
WanVAE_tiny
(
vae_pth
=
tiny_vae_path
,
device
=
self
.
init_device
,
need_scaled
=
self
.
config
.
get
(
"need_scaled"
,
False
)).
to
(
"cuda"
)
vae_pth
=
tiny_vae_path
,
device
=
self
.
init_device
,
).
to
(
"cuda"
)
else
:
else
:
vae_decoder
=
WanVAE
(
**
vae_config
)
vae_decoder
=
WanVAE
(
**
vae_config
)
return
vae_decoder
return
vae_decoder
...
...
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
100644 → 100755
View file @
89ce2aa6
...
@@ -12,18 +12,66 @@ class DotDict(dict):
...
@@ -12,18 +12,66 @@ class DotDict(dict):
class WanVAE_tiny(nn.Module):
    """Lightweight decoder-side stand-in for the full Wan 2.1 VAE.

    Wraps a TAEHV tiny autoencoder while exposing the small surface the
    runner expects from the full ``WanVAE``: a ``config`` object (identity
    scaling by default), ``temperal_downsample`` flags, and a ``decode``
    method taking latents of shape (C, T, H, W).
    """

    def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
        """Build the tiny VAE decoder.

        Args:
            vae_pth: Path to the TAEHV checkpoint file.
            dtype: Dtype the tiny decoder network is cast to.
            device: Accepted for interface compatibility with the full VAE;
                see the NOTE below — it is currently not honored.
            need_scaled: If True, un-normalize incoming latents with the
                Wan 2.1 per-channel mean/std before decoding.
        """
        super().__init__()
        self.dtype = dtype
        # NOTE(review): the `device` argument is ignored — the module is
        # hard-pinned to CUDA here. Confirm callers before honoring it.
        self.device = torch.device("cuda")
        self.taehv = TAEHV(vae_pth).to(self.dtype)
        self.temperal_downsample = [True, True, False]
        # Mimic the full VAE's config surface with identity scaling so
        # callers that consult these values see a no-op normalization.
        self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
        self.need_scaled = need_scaled
        if self.need_scaled:
            # Wan 2.1 per-channel latent statistics (16 channels), applied
            # in decode() to un-normalize latents before the tiny decoder.
            self.latents_mean = [
                -0.7571,
                -0.7089,
                -0.9113,
                0.1075,
                -0.1745,
                0.9653,
                -0.1517,
                1.5508,
                0.4134,
                -0.0715,
                0.5517,
                -0.3632,
                -0.1922,
                -0.9497,
                0.2503,
                -0.2921,
            ]
            self.latents_std = [
                2.8184,
                1.4541,
                2.3275,
                2.6558,
                1.2196,
                1.7708,
                2.6052,
                2.0743,
                3.2687,
                2.1526,
                2.8652,
                1.5579,
                1.6382,
                1.1253,
                2.8251,
                1.9160,
            ]
            self.z_dim = 16

    @peak_memory_decorator
    @torch.no_grad()
    def decode(self, latents):
        """Decode a latent video tensor to pixel space in [-1, 1].

        Args:
            latents: Latent tensor of shape (C, T, H, W) — a batch axis is
                prepended internally. Presumably C == 16 when
                ``need_scaled`` is set; TODO confirm against callers.

        Returns:
            Decoded video tensor with a leading batch axis, values mapped
            into [-1, 1] in place.
        """
        latents = latents.unsqueeze(0)  # -> (1, C, T, H, W)
        if self.need_scaled:
            # Un-normalize: latents / (1/std) + mean == latents * std + mean
            # (same idiom as the diffusers Wan pipeline decode path).
            latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
            latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
            latents = latents / latents_std + latents_mean
        # TAEHV consumes (N, T, C, H, W), hence the transposes around the
        # call. low-memory path: set parallel=True for faster + higher memory.
        # decode_video output is rescaled from [0, 1] to [-1, 1] in place.
        return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment