xuwx1 / LightX2V · Commits

Commit 9908528b, authored Aug 20, 2025 by gushiqiao
Parent: d061ae81

[fea] Update patch vae
Showing 3 changed files with 168 additions and 0 deletions:

- lightx2v/models/video_encoders/hf/wan/dist/__init__.py (+0, -0)
- lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py (+72, -0)
- lightx2v/models/video_encoders/hf/wan/dist/split_gather.py (+96, -0)
lightx2v/models/video_encoders/hf/wan/dist/__init__.py (new file, mode 100755, empty)
lightx2v/models/video_encoders/hf/wan/dist/distributed_env.py (new file, mode 100755)
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/distributed_env.py
import os

import torch.distributed as dist
from torch.distributed import ProcessGroup


class DistributedEnv:
    """Process-group bookkeeping for the patch-parallel VAE."""

    _vae_group = None
    _local_rank = None
    _world_size = None
    _rank_mapping = None  # group rank -> global rank; filled in by _init_rank_mapping()

    @classmethod
    def initialize(cls, vae_group: ProcessGroup):
        if vae_group is None:
            # Default to the world group when no dedicated VAE group is given.
            cls._vae_group = dist.group.WORLD
        else:
            cls._vae_group = vae_group
        cls._local_rank = int(os.environ.get("LOCAL_RANK", 0))  # FIXME: in ray all local_rank is 0
        cls._rank_mapping = None
        cls._init_rank_mapping()

    @classmethod
    def get_vae_group(cls) -> ProcessGroup:
        if cls._vae_group is None:
            raise RuntimeError("DistributedEnv not initialized. Call initialize() first.")
        return cls._vae_group

    @classmethod
    def get_global_rank(cls) -> int:
        return dist.get_rank()

    @classmethod
    def _init_rank_mapping(cls):
        """Initialize the mapping between group ranks and global ranks."""
        if cls._rank_mapping is None:
            # Gather every member's global rank; the result is ordered by group rank.
            ranks = [None] * cls.get_group_world_size()
            dist.all_gather_object(ranks, cls.get_global_rank(), group=cls.get_vae_group())
            cls._rank_mapping = ranks

    @classmethod
    def get_global_rank_from_group_rank(cls, group_rank: int) -> int:
        """Convert a rank in the VAE group to a global rank using the cached mapping.

        Args:
            group_rank: The rank in the VAE group.

        Returns:
            The corresponding global rank.

        Raises:
            RuntimeError: If the group_rank is invalid.
        """
        if cls._rank_mapping is None:
            cls._init_rank_mapping()
        if group_rank < 0 or group_rank >= cls.get_group_world_size():
            raise RuntimeError(f"Invalid group rank: {group_rank}. Must be in range [0, {cls.get_group_world_size() - 1}]")
        return cls._rank_mapping[group_rank]

    @classmethod
    def get_rank_in_vae_group(cls) -> int:
        return dist.get_rank(cls.get_vae_group())

    @classmethod
    def get_group_world_size(cls) -> int:
        return dist.get_world_size(cls.get_vae_group())

    @classmethod
    def get_local_rank(cls) -> int:
        return cls._local_rank
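
For orientation, a minimal usage sketch. Assumptions: a torchrun launch with the NCCL backend, and the default case where every rank joins the VAE group; this setup code is not part of the commit.

# Hypothetical setup sketch; only DistributedEnv.initialize() comes from this commit.
import os

import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv


def init_vae_parallel():
    # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK for each process.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    # vae_group=None falls back to dist.group.WORLD, so all ranks join the VAE group.
    DistributedEnv.initialize(vae_group=None)
    assert DistributedEnv.get_group_world_size() == dist.get_world_size()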
lightx2v/models/video_encoders/hf/wan/dist/split_gather.py (new file, mode 100755)
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/split_gather.py
import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv


def _gather(patch_hidden_state, dim=-1, group=None):
    # Note: the patch is always gathered along dim 3 (height) across the
    # DistributedEnv VAE group; the dim/group arguments are currently unused.
    group_world_size = DistributedEnv.get_group_world_size()
    local_rank = DistributedEnv.get_local_rank()
    # Exchange each rank's patch height first, since the last chunk may be shorter.
    patch_height_list = [torch.empty([1], dtype=torch.int64, device=f"cuda:{local_rank}") for _ in range(group_world_size)]
    dist.all_gather(
        patch_height_list,
        torch.tensor([patch_hidden_state.shape[3]], dtype=torch.int64, device=f"cuda:{local_rank}"),
        group=DistributedEnv.get_vae_group(),
    )
    # Allocate receive buffers sized by the gathered heights, then gather the patches.
    patch_hidden_state_list = [
        torch.zeros(
            [
                patch_hidden_state.shape[0],
                patch_hidden_state.shape[1],
                patch_hidden_state.shape[2],
                patch_height_list[i].item(),
                patch_hidden_state.shape[4],
            ],
            dtype=patch_hidden_state.dtype,
            device=f"cuda:{local_rank}",
            requires_grad=patch_hidden_state.requires_grad,
        )
        for i in range(group_world_size)
    ]
    dist.all_gather(patch_hidden_state_list, patch_hidden_state.contiguous(), group=DistributedEnv.get_vae_group())
    # Reassemble the full tensor along the height dimension.
    output = torch.cat(patch_hidden_state_list, dim=3)
    return output


def _split(inputs, dim=-1, group=None):
    # Note: the split likewise always acts on dim 3 (height).
    group_world_size = DistributedEnv.get_group_world_size()
    rank_in_vae_group = DistributedEnv.get_rank_in_vae_group()
    height = inputs.shape[3]
    # Ceil-division chunking: every rank takes ceil(height / world_size) rows,
    # except the last rank, which takes whatever remains.
    start_idx = (height + group_world_size - 1) // group_world_size * rank_in_vae_group
    end_idx = min((height + group_world_size - 1) // group_world_size * (rank_in_vae_group + 1), height)
    return inputs[:, :, :, start_idx:end_idx, :].clone()
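
# Worked example of the chunking above (illustrative note, not in the original
# file): with height = 10 and group_world_size = 4, the chunk size is
# ceil(10 / 4) = 3, so group ranks 0..3 receive the height slices
# [0:3], [3:6], [6:9], and [9:10]; the last rank gets the shorter remainder,
# which is why _gather must exchange per-rank heights before concatenating.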
class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _split(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        # dim and group are non-tensor inputs, so their gradients are None.
        return _gather(grad_output, ctx.dim, ctx.group), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Args:
        inputs: input matrix.
        dim: dimension
        group: process group
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _gather(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.group), None, None


def split_forward_gather_backward(group, inputs, dim):
    return _SplitForwardGatherBackward.apply(inputs, dim, group)


def gather_forward_split_backward(group, inputs, dim):
    return _GatherForwardSplitBackward.apply(inputs, dim, group)
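
To show how the two wrappers pair up, here is a hypothetical forward pass over a [batch, channel, frame, height, width] hidden state. patch_parallel_forward and block are illustrative names, not APIs from this commit, and the sketch assumes the block preserves the height dimension.

# Hypothetical usage sketch; only the two wrapper functions come from this commit.
from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
from lightx2v.models.video_encoders.hf.wan.dist.split_gather import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def patch_parallel_forward(block, hidden_state):
    # hidden_state: [batch, channel, frame, height, width]; dim 3 is height.
    group = DistributedEnv.get_vae_group()
    # Forward: keep only this rank's height slice; backward: gather the full gradient.
    patch = split_forward_gather_backward(group, hidden_state, dim=3)
    patch = block(patch)  # assumed to preserve the height dimension
    # Forward: reassemble the full height on every rank; backward: re-split the gradient.
    return gather_forward_split_backward(group, patch, dim=3)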