xuwx1 / LightX2V

Commit c93c756c
Authored Aug 28, 2025 by Yang Yong(雍洋), committed by GitHub on Aug 28, 2025

Support vae encode dist infer & Remove approximate_patch vae for its bad precision. (#255)

Parent: bba65ffd
Showing 4 changed files with 85 additions and 197 deletions (+85 −197)
lightx2v/models/video_encoders/hf/wan/dist/__init__.py          +0  −0
lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py   +0  −72
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py      +0  −96
lightx2v/models/video_encoders/hf/wan/vae.py                    +85 −29
lightx2v/models/video_encoders/hf/wan/dist/__init__.py (deleted, mode 100755 → 0, empty file)
lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py (deleted, mode 100755 → 0)
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/distributed_env.py
import torch.distributed as dist
from torch.distributed import ProcessGroup
import os


class DistributedEnv:
    _vae_group = None
    _local_rank = None
    _world_size = None

    @classmethod
    def initialize(cls, vae_group: ProcessGroup):
        if vae_group is None:
            cls._vae_group = dist.group.WORLD
        else:
            cls._vae_group = vae_group
        cls._local_rank = int(os.environ.get("LOCAL_RANK", 0))  # FIXME: in ray all local_rank is 0
        cls._rank_mapping = None
        cls._init_rank_mapping()

    @classmethod
    def get_vae_group(cls) -> ProcessGroup:
        if cls._vae_group is None:
            raise RuntimeError("DistributedEnv not initialized. Call initialize() first.")
        return cls._vae_group

    @classmethod
    def get_global_rank(cls) -> int:
        return dist.get_rank()

    @classmethod
    def _init_rank_mapping(cls):
        """Initialize the mapping between group ranks and global ranks."""
        if cls._rank_mapping is None:
            # Get all ranks in the group
            ranks = [None] * cls.get_group_world_size()
            dist.all_gather_object(ranks, cls.get_global_rank(), group=cls.get_vae_group())
            cls._rank_mapping = ranks

    @classmethod
    def get_global_rank_from_group_rank(cls, group_rank: int) -> int:
        """Convert a rank in the VAE group to a global rank using the cached mapping.

        Args:
            group_rank: The rank in the VAE group
        Returns:
            The corresponding global rank
        Raises:
            RuntimeError: If the group_rank is invalid
        """
        if cls._rank_mapping is None:
            cls._init_rank_mapping()
        if group_rank < 0 or group_rank >= cls.get_group_world_size():
            raise RuntimeError(f"Invalid group rank: {group_rank}. Must be in range [0, {cls.get_group_world_size() - 1}]")
        return cls._rank_mapping[group_rank]

    @classmethod
    def get_rank_in_vae_group(cls) -> int:
        return dist.get_rank(cls.get_vae_group())

    @classmethod
    def get_group_world_size(cls) -> int:
        return dist.get_world_size(cls.get_vae_group())

    @classmethod
    def get_local_rank(cls) -> int:
        return cls._local_rank
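For reference, a minimal usage sketch of this deleted helper (not part of the commit), assuming a torchrun launch with NCCL and one GPU per rank:

import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")              # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
DistributedEnv.initialize(None)                      # None falls back to dist.group.WORLD

# Query the cached topology; with the world group as VAE group these coincide.
print(DistributedEnv.get_global_rank(),
      DistributedEnv.get_rank_in_vae_group(),
      DistributedEnv.get_group_world_size())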
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py (deleted, mode 100755 → 0)
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/split_gather.py
import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv


def _gather(patch_hidden_state, dim=-1, group=None):
    group_world_size = DistributedEnv.get_group_world_size()
    local_rank = DistributedEnv.get_local_rank()
    # First exchange per-rank patch heights, since height chunks may be uneven.
    patch_height_list = [torch.empty([1], dtype=torch.int64, device=f"cuda:{local_rank}") for _ in range(group_world_size)]
    dist.all_gather(
        patch_height_list,
        torch.tensor([patch_hidden_state.shape[3]], dtype=torch.int64, device=f"cuda:{local_rank}"),
        group=DistributedEnv.get_vae_group(),
    )
    # Then gather the patches themselves and concatenate along the height axis.
    patch_hidden_state_list = [
        torch.zeros(
            [patch_hidden_state.shape[0], patch_hidden_state.shape[1], patch_hidden_state.shape[2], patch_height_list[i].item(), patch_hidden_state.shape[4]],
            dtype=patch_hidden_state.dtype,
            device=f"cuda:{local_rank}",
            requires_grad=patch_hidden_state.requires_grad,
        )
        for i in range(group_world_size)
    ]
    dist.all_gather(patch_hidden_state_list, patch_hidden_state.contiguous(), group=DistributedEnv.get_vae_group())
    output = torch.cat(patch_hidden_state_list, dim=3)
    return output


def _split(inputs, dim=-1, group=None):
    group_world_size = DistributedEnv.get_group_world_size()
    rank_in_vae_group = DistributedEnv.get_rank_in_vae_group()
    # Ceil-divide the height axis so every rank gets a contiguous slice.
    height = inputs.shape[3]
    start_idx = (height + group_world_size - 1) // group_world_size * rank_in_vae_group
    end_idx = min((height + group_world_size - 1) // group_world_size * (rank_in_vae_group + 1), height)
    return inputs[:, :, :, start_idx:end_idx, :].clone()


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _split(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.dim, ctx.group), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _gather(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.group), None, None


def split_forward_gather_backward(group, inputs, dim):
    return _SplitForwardGatherBackward.apply(inputs, dim, group)


def gather_forward_split_backward(group, inputs, dim):
    return _GatherForwardSplitBackward.apply(inputs, dim, group)
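Continuing the sketch above, a hedged round-trip example (not from the repository): splitting a [B, C, T, H, W] tensor along height and gathering it back. Note that _split and _gather hardcode the height axis (shape[3] and dim=3), so the dim argument is effectively decorative.

import torch

# Assumes the process group and DistributedEnv from the previous sketch.
x = torch.randn(1, 16, 9, 64, 64, device=f"cuda:{DistributedEnv.get_local_rank()}")
patch = split_forward_gather_backward(None, x, 3)     # each rank keeps ~H/world_size rows
full = gather_forward_split_backward(None, patch, 3)  # all_gather + cat along height
assert full.shape == x.shape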
lightx2v/models/video_encoders/hf/wan/vae.py (view file @ c93c756c)
@@ -7,8 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger
-from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
-from lightx2v.models.video_encoders.hf.wan.dist.split_gather import gather_forward_split_backward, split_forward_gather_backward
 from lightx2v.utils.utils import load_weights

 __all__ = [
@@ -519,7 +517,6 @@ class WanVAE_(nn.Module):
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
-        self.use_approximate_patch = False

         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
@@ -550,12 +547,6 @@ class WanVAE_(nn.Module):
                 dropout,
             )

-    def enable_approximate_patch(self):
-        self.use_approximate_patch = True
-
-    def disable_approximate_patch(self):
-        self.use_approximate_patch = False
-
     def forward(self, x):
         mu, log_var = self.encode(x)
         z = self.reparameterize(mu, log_var)
@@ -638,9 +629,6 @@ class WanVAE_(nn.Module):
         return enc

     def tiled_decode(self, z, scale):
-        if self.use_approximate_patch:
-            z = split_forward_gather_backward(None, z, 3)
-
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
         else:
@@ -690,8 +678,6 @@ class WanVAE_(nn.Module):
             result_rows.append(torch.cat(result_row, dim=-1))

         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
-        if self.use_approximate_patch:
-            dec = gather_forward_split_backward(None, dec, 3)
         return dec
@@ -726,8 +712,6 @@ class WanVAE_(nn.Module):
     def decode(self, z, scale):
         self.clear_cache()
-        if self.use_approximate_patch:
-            z = split_forward_gather_backward(None, z, 3)
         # z: [b,c,t,h,w]
         if isinstance(scale[0], torch.Tensor):
@@ -752,9 +736,6 @@ class WanVAE_(nn.Module):
                 )
             out = torch.cat([out, out_], 2)

-        if self.use_approximate_patch:
-            out = gather_forward_split_backward(None, out, 3)
-
         self.clear_cache()
         return out
@@ -866,12 +847,6 @@ class WanVAE:
         # init model
         self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
-        self.use_approximate_patch = False
-        if self.parallel and self.parallel.get("use_patch_vae", False):
-            # assert not self.use_tiling
-            DistributedEnv.initialize(None)
-            self.use_approximate_patch = True
-            self.model.enable_approximate_patch()

     def current_device(self):
         return next(self.model.parameters()).device
@@ -892,6 +867,70 @@ class WanVAE:
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

+    def encode_dist(self, video, world_size, cur_rank, split_dim):
+        spatial_ratio = 8
+        if split_dim == 3:
+            total_latent_len = video.shape[3] // spatial_ratio
+        elif split_dim == 4:
+            total_latent_len = video.shape[4] // spatial_ratio
+        else:
+            raise ValueError(f"Unsupported split_dim: {split_dim}")
+
+        splited_chunk_len = total_latent_len // world_size
+        padding_size = 1
+        video_chunk_len = splited_chunk_len * spatial_ratio
+        video_padding_len = padding_size * spatial_ratio
+
+        if cur_rank == 0:
+            if split_dim == 3:
+                video_chunk = video[:, :, :, : video_chunk_len + 2 * video_padding_len, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, : video_chunk_len + 2 * video_padding_len].contiguous()
+        elif cur_rank == world_size - 1:
+            if split_dim == 3:
+                video_chunk = video[:, :, :, -(video_chunk_len + 2 * video_padding_len) :, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, -(video_chunk_len + 2 * video_padding_len) :].contiguous()
+        else:
+            start_idx = cur_rank * video_chunk_len - video_padding_len
+            end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len
+            if split_dim == 3:
+                video_chunk = video[:, :, :, start_idx:end_idx, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous()
+
+        if self.use_tiling:
+            encoded_chunk = self.model.tiled_encode(video_chunk, self.scale).float()
+        else:
+            encoded_chunk = self.model.encode(video_chunk, self.scale).float()
+
+        if cur_rank == 0:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, :splited_chunk_len, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, :splited_chunk_len].contiguous()
+        elif cur_rank == world_size - 1:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, -splited_chunk_len:, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, -splited_chunk_len:].contiguous()
+        else:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, padding_size:-padding_size, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, padding_size:-padding_size].contiguous()
+
+        full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
+        dist.all_gather(full_encoded, encoded_chunk)
+        torch.cuda.synchronize()
+        encoded = torch.cat(full_encoded, dim=split_dim)
+        return encoded.squeeze(0)
+
     def encode(self, video):
         """
         video: one video with shape [1, C, T, H, W].
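To make the chunking arithmetic above concrete, a worked example with assumed numbers (H = 480 pixel rows, world_size = 4, split_dim = 3; none of these values appear in the commit):

spatial_ratio = 8                                    # VAE downsamples H and W by 8
total_latent_len = 480 // spatial_ratio              # 60 latent rows overall
splited_chunk_len = total_latent_len // 4            # 15 latent rows per rank
video_chunk_len = splited_chunk_len * spatial_ratio  # 120 pixel rows per rank
video_padding_len = 1 * spatial_ratio                # 8-pixel halo per side
# A middle rank encodes 120 + 2 * 8 = 136 pixel rows -> 17 latent rows, then
# trims padding_size = 1 latent row from each side, leaving 15, so the final
# all_gather concatenates 4 x 15 = 60 latent rows, matching a single-GPU
# encode up to boundary effects inside the halo.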
@@ -899,6 +938,23 @@ class WanVAE:
         if self.cpu_offload:
             self.to_cuda()
+        if self.parallel:
+            world_size = dist.get_world_size()
+            cur_rank = dist.get_rank()
+            height, width = video.shape[3], video.shape[4]
+            # Check if dimensions are divisible by world_size
+            if width % world_size == 0:
+                out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
+            elif height % world_size == 0:
+                out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
+            else:
+                logger.info("Fall back to naive encode mode")
+                if self.use_tiling:
+                    out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+                else:
+                    out = self.model.encode(video, self.scale).float().squeeze(0)
+        else:
             if self.use_tiling:
                 out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
             else:
@@ -961,7 +1017,7 @@ class WanVAE:
         if self.cpu_offload:
             self.to_cuda()
-        if self.parallel and not self.use_approximate_patch:
+        if self.parallel:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
             height, width = zs.shape[2], zs.shape[3]
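Finally, a hypothetical driver for the new distributed encode path (not part of the commit; the WanVAE constructor arguments and the parallel config shown here are assumptions), launched e.g. with torchrun --nproc_per_node=4:

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

vae = WanVAE(vae_pth="Wan2.1_VAE.pth", parallel={"enable": True})  # args assumed
video = torch.randn(1, 3, 81, 480, 832, device="cuda")             # [1, C, T, H, W]
latents = vae.encode(video)  # W = 832 divides by 4 -> encode_dist(split_dim=4)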