Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
c93c756c
Commit
c93c756c
authored
Aug 28, 2025
by
Yang Yong(雍洋)
Committed by
GitHub
Aug 28, 2025
Browse files
Support vae encode dist infer & Remove approximate_patch vae for its bad precision. (#255)
parent
bba65ffd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
197 deletions
+85
-197
lightx2v/models/video_encoders/hf/wan/dist/__init__.py
lightx2v/models/video_encoders/hf/wan/dist/__init__.py
+0
-0
lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py
...tx2v/models/video_encoders/hf/wan/dist/distributed_env.py
+0
-72
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py
+0
-96
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+85
-29
No files found.
lightx2v/models/video_encoders/hf/wan/dist/__init__.py
deleted
100755 → 0
View file @
bba65ffd
lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py
deleted
100755 → 0
View file @
bba65ffd
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/distributed_env.py
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
import
os
class DistributedEnv:
    """Process-group bookkeeping for distributed VAE inference.

    Holds a class-level (singleton-style) reference to the VAE process group
    plus a cached mapping from group-local ranks to global ranks.
    Call ``initialize()`` once before using any getter.
    """

    _vae_group = None
    _local_rank = None
    _world_size = None
    # Cached list: index = rank in VAE group, value = global rank.
    # Declared here so the lazy-init check in
    # get_global_rank_from_group_rank() cannot raise AttributeError when the
    # method is reached before initialize() has run.
    _rank_mapping = None

    @classmethod
    def initialize(cls, vae_group: ProcessGroup):
        """Bind the VAE process group (WORLD if ``vae_group`` is None) and
        build the group-rank -> global-rank mapping."""
        if vae_group is None:
            cls._vae_group = dist.group.WORLD
        else:
            cls._vae_group = vae_group
        cls._local_rank = int(os.environ.get('LOCAL_RANK', 0))  # FIXME: in ray all local_rank is 0
        cls._rank_mapping = None
        cls._init_rank_mapping()

    @classmethod
    def get_vae_group(cls) -> ProcessGroup:
        """Return the VAE process group.

        Raises:
            RuntimeError: if ``initialize()`` has not been called.
        """
        if cls._vae_group is None:
            raise RuntimeError("DistributedEnv not initialized. Call initialize() first.")
        return cls._vae_group

    @classmethod
    def get_global_rank(cls) -> int:
        """Return this process's rank in the default (WORLD) group."""
        return dist.get_rank()

    @classmethod
    def _init_rank_mapping(cls):
        """Initialize the mapping between group ranks and global ranks."""
        if cls._rank_mapping is None:
            # Every rank contributes its global rank; all_gather_object orders
            # the results by rank within the VAE group.
            ranks = [None] * cls.get_group_world_size()
            dist.all_gather_object(ranks, cls.get_global_rank(), group=cls.get_vae_group())
            cls._rank_mapping = ranks

    @classmethod
    def get_global_rank_from_group_rank(cls, group_rank: int) -> int:
        """Convert a rank in VAE group to global rank using cached mapping.

        Args:
            group_rank: The rank in VAE group

        Returns:
            The corresponding global rank

        Raises:
            RuntimeError: If the group_rank is invalid
        """
        if cls._rank_mapping is None:
            cls._init_rank_mapping()
        if group_rank < 0 or group_rank >= cls.get_group_world_size():
            raise RuntimeError(f"Invalid group rank: {group_rank}. Must be in range [0, {cls.get_group_world_size() - 1}]")
        return cls._rank_mapping[group_rank]

    @classmethod
    def get_rank_in_vae_group(cls) -> int:
        """Return this process's rank within the VAE group."""
        return dist.get_rank(cls.get_vae_group())

    @classmethod
    def get_group_world_size(cls) -> int:
        """Return the number of ranks in the VAE group."""
        return dist.get_world_size(cls.get_vae_group())

    @classmethod
    def get_local_rank(cls) -> int:
        """Return the node-local rank captured from LOCAL_RANK at initialize()."""
        return cls._local_rank
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py
deleted
100755 → 0
View file @
bba65ffd
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/split_gather.py
import
torch
import
torch.distributed
as
dist
from
lightx2v.models.video_encoders.hf.wan.dist.distributed_env
import
DistributedEnv
def _gather(patch_hidden_state, dim=-1, group=None):
    """All-gather height-sliced patches (dim 3) from every rank in the VAE
    group and concatenate them into the full tensor.

    Patches may have different heights, so the ranks first exchange their
    heights, then pre-size one receive buffer per rank.
    Assumes a 5-D input laid out [b, c, t, h, w] — TODO confirm with callers.
    """
    world_size = DistributedEnv.get_group_world_size()
    device = f"cuda:{DistributedEnv.get_local_rank()}"
    vae_group = DistributedEnv.get_vae_group()

    # Step 1: exchange every rank's patch height.
    heights = [torch.empty([1], dtype=torch.int64, device=device) for _ in range(world_size)]
    my_height = torch.tensor([patch_hidden_state.shape[3]], dtype=torch.int64, device=device)
    dist.all_gather(heights, my_height, group=vae_group)

    # Step 2: allocate a correctly-sized receive buffer per rank.
    batch, channels, frames, _, width = patch_hidden_state.shape
    buffers = []
    for h in heights:
        buf = torch.zeros(
            [batch, channels, frames, h.item(), width],
            dtype=patch_hidden_state.dtype,
            device=device,
            requires_grad=patch_hidden_state.requires_grad,
        )
        buffers.append(buf)

    # Step 3: gather the actual patches and stitch them along height.
    dist.all_gather(buffers, patch_hidden_state.contiguous(), group=vae_group)
    return torch.cat(buffers, dim=3)
def _split(inputs, dim=-1, group=None):
    """Return this rank's contiguous height-slice (dim 3) of ``inputs``.

    The height is divided into ceil(height / world_size)-sized chunks; the
    last rank's chunk is clipped to the tensor boundary. The slice is cloned
    so the caller owns independent storage.
    """
    world_size = DistributedEnv.get_group_world_size()
    rank = DistributedEnv.get_rank_in_vae_group()
    height = inputs.shape[3]
    # Ceiling division: every rank except possibly the last gets a full chunk.
    chunk = -(-height // world_size)
    start = chunk * rank
    stop = min(chunk * (rank + 1), height)
    return inputs[:, :, :, start:stop, :].clone()
class _SplitForwardGatherBackward(torch.autograd.Function):
    """Autograd op that splits on forward and gathers gradients on backward.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        # Stash dim/group so backward can reverse the operation.
        ctx.dim = dim
        ctx.group = group
        return _split(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad = _gather(grad_output, ctx.dim, ctx.group)
        # No gradients for the `dim` and `group` arguments.
        return grad, None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate; split the
    gradient back on the backward pass.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        # Remember how to undo the gather during backprop.
        ctx.dim = dim
        ctx.group = group
        return _gather(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad = _split(grad_output, ctx.dim, ctx.group)
        # `dim` and `group` receive no gradient.
        return grad, None, None
def split_forward_gather_backward(group, inputs, dim):
    """Split ``inputs`` along ``dim`` for the forward pass; gradients are
    gathered back on the backward pass."""
    op = _SplitForwardGatherBackward
    return op.apply(inputs, dim, group)
def gather_forward_split_backward(group, inputs, dim):
    """Gather ``inputs`` along ``dim`` for the forward pass; gradients are
    split back on the backward pass."""
    op = _GatherForwardSplitBackward
    return op.apply(inputs, dim, group)
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
c93c756c
...
...
@@ -7,8 +7,6 @@ import torch.nn.functional as F
from
einops
import
rearrange
from
loguru
import
logger
from
lightx2v.models.video_encoders.hf.wan.dist.distributed_env
import
DistributedEnv
from
lightx2v.models.video_encoders.hf.wan.dist.split_gather
import
gather_forward_split_backward
,
split_forward_gather_backward
from
lightx2v.utils.utils
import
load_weights
__all__
=
[
...
...
@@ -519,7 +517,6 @@ class WanVAE_(nn.Module):
self
.
temperal_downsample
=
temperal_downsample
self
.
temperal_upsample
=
temperal_downsample
[::
-
1
]
self
.
spatial_compression_ratio
=
2
**
len
(
self
.
temperal_downsample
)
self
.
use_approximate_patch
=
False
# The minimal tile height and width for spatial tiling to be used
self
.
tile_sample_min_height
=
256
...
...
@@ -550,12 +547,6 @@ class WanVAE_(nn.Module):
dropout
,
)
def
enable_approximate_patch
(
self
):
self
.
use_approximate_patch
=
True
def
disable_approximate_patch
(
self
):
self
.
use_approximate_patch
=
False
def
forward
(
self
,
x
):
mu
,
log_var
=
self
.
encode
(
x
)
z
=
self
.
reparameterize
(
mu
,
log_var
)
...
...
@@ -638,9 +629,6 @@ class WanVAE_(nn.Module):
return
enc
def
tiled_decode
(
self
,
z
,
scale
):
if
self
.
use_approximate_patch
:
z
=
split_forward_gather_backward
(
None
,
z
,
3
)
if
isinstance
(
scale
[
0
],
torch
.
Tensor
):
z
=
z
/
scale
[
1
].
view
(
1
,
self
.
z_dim
,
1
,
1
,
1
)
+
scale
[
0
].
view
(
1
,
self
.
z_dim
,
1
,
1
,
1
)
else
:
...
...
@@ -690,8 +678,6 @@ class WanVAE_(nn.Module):
result_rows
.
append
(
torch
.
cat
(
result_row
,
dim
=-
1
))
dec
=
torch
.
cat
(
result_rows
,
dim
=
3
)[:,
:,
:,
:
sample_height
,
:
sample_width
]
if
self
.
use_approximate_patch
:
dec
=
gather_forward_split_backward
(
None
,
dec
,
3
)
return
dec
...
...
@@ -726,8 +712,6 @@ class WanVAE_(nn.Module):
def
decode
(
self
,
z
,
scale
):
self
.
clear_cache
()
if
self
.
use_approximate_patch
:
z
=
split_forward_gather_backward
(
None
,
z
,
3
)
# z: [b,c,t,h,w]
if
isinstance
(
scale
[
0
],
torch
.
Tensor
):
...
...
@@ -752,9 +736,6 @@ class WanVAE_(nn.Module):
)
out
=
torch
.
cat
([
out
,
out_
],
2
)
if
self
.
use_approximate_patch
:
out
=
gather_forward_split_backward
(
None
,
out
,
3
)
self
.
clear_cache
()
return
out
...
...
@@ -866,12 +847,6 @@ class WanVAE:
# init model
self
.
model
=
_video_vae
(
pretrained_path
=
vae_pth
,
z_dim
=
z_dim
,
cpu_offload
=
cpu_offload
).
eval
().
requires_grad_
(
False
).
to
(
device
)
self
.
use_approximate_patch
=
False
if
self
.
parallel
and
self
.
parallel
.
get
(
"use_patch_vae"
,
False
):
# assert not self.use_tiling
DistributedEnv
.
initialize
(
None
)
self
.
use_approximate_patch
=
True
self
.
model
.
enable_approximate_patch
()
def
current_device
(
self
):
return
next
(
self
.
model
.
parameters
()).
device
...
...
@@ -892,6 +867,70 @@ class WanVAE:
self
.
inv_std
=
self
.
inv_std
.
cuda
()
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def encode_dist(self, video, world_size, cur_rank, split_dim):
    """Distributed VAE encode of ``video`` split across ranks.

    Each rank encodes an overlapping chunk of the video along ``split_dim``
    (3 = height, 4 = width), trims the one-latent halo from its latent
    chunk, and all-gathers the trimmed chunks into the full latent tensor.

    Args:
        video: input tensor, assumed [1, C, T, H, W] — TODO confirm.
        world_size: number of participating ranks.
        cur_rank: this process's rank, in [0, world_size).
        split_dim: dimension to shard along; must be 3 or 4.

    Returns:
        The gathered latent tensor with the leading batch dim squeezed.

    Raises:
        ValueError: if ``split_dim`` is not 3 or 4.
    """
    spatial_ratio = 8  # pixel-to-latent spatial downsample factor
    if split_dim not in (3, 4):
        raise ValueError(f"Unsupported split_dim: {split_dim}")

    def take(tensor, sl):
        # Slice a 5-D tensor along split_dim only; replaces the six
        # near-identical `if split_dim == 3 / elif == 4` branches of the
        # original formulation.
        index = [slice(None)] * 5
        index[split_dim] = sl
        return tensor[tuple(index)].contiguous()

    total_latent_len = video.shape[split_dim] // spatial_ratio
    splited_chunk_len = total_latent_len // world_size
    padding_size = 1  # halo width in latent units
    video_chunk_len = splited_chunk_len * spatial_ratio
    video_padding_len = padding_size * spatial_ratio

    # Every rank takes its chunk plus a halo; edge ranks take the whole
    # 2*halo on their inner side so all chunks have equal length.
    if cur_rank == 0:
        video_chunk = take(video, slice(None, video_chunk_len + 2 * video_padding_len))
    elif cur_rank == world_size - 1:
        video_chunk = take(video, slice(-(video_chunk_len + 2 * video_padding_len), None))
    else:
        start_idx = cur_rank * video_chunk_len - video_padding_len
        end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len
        video_chunk = take(video, slice(start_idx, end_idx))

    if self.use_tiling:
        encoded_chunk = self.model.tiled_encode(video_chunk, self.scale).float()
    else:
        encoded_chunk = self.model.encode(video_chunk, self.scale).float()

    # Drop the halo latents so concatenated chunks don't overlap.
    if cur_rank == 0:
        encoded_chunk = take(encoded_chunk, slice(None, splited_chunk_len))
    elif cur_rank == world_size - 1:
        encoded_chunk = take(encoded_chunk, slice(-splited_chunk_len, None))
    else:
        encoded_chunk = take(encoded_chunk, slice(padding_size, -padding_size))

    # Equal chunk lengths on every rank allow a plain all_gather.
    full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
    dist.all_gather(full_encoded, encoded_chunk)
    torch.cuda.synchronize()
    encoded = torch.cat(full_encoded, dim=split_dim)
    return encoded.squeeze(0)
def
encode
(
self
,
video
):
"""
video: one video with shape [1, C, T, H, W].
...
...
@@ -899,10 +938,27 @@ class WanVAE:
if
self
.
cpu_offload
:
self
.
to_cuda
()
if
self
.
use_tiling
:
out
=
self
.
model
.
tiled_encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
if
self
.
parallel
:
world_size
=
dist
.
get_world_size
()
cur_rank
=
dist
.
get_rank
()
height
,
width
=
video
.
shape
[
3
],
video
.
shape
[
4
]
# Check if dimensions are divisible by world_size
if
width
%
world_size
==
0
:
out
=
self
.
encode_dist
(
video
,
world_size
,
cur_rank
,
split_dim
=
4
)
elif
height
%
world_size
==
0
:
out
=
self
.
encode_dist
(
video
,
world_size
,
cur_rank
,
split_dim
=
3
)
else
:
logger
.
info
(
"Fall back to naive encode mode"
)
if
self
.
use_tiling
:
out
=
self
.
model
.
tiled_encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
else
:
out
=
self
.
model
.
encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
else
:
out
=
self
.
model
.
encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
if
self
.
use_tiling
:
out
=
self
.
model
.
tiled_encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
else
:
out
=
self
.
model
.
encode
(
video
,
self
.
scale
).
float
().
squeeze
(
0
)
if
self
.
cpu_offload
:
self
.
to_cpu
()
...
...
@@ -961,7 +1017,7 @@ class WanVAE:
if
self
.
cpu_offload
:
self
.
to_cuda
()
if
self
.
parallel
and
not
self
.
use_approximate_patch
:
if
self
.
parallel
:
world_size
=
dist
.
get_world_size
()
cur_rank
=
dist
.
get_rank
()
height
,
width
=
zs
.
shape
[
2
],
zs
.
shape
[
3
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment