xuwx1 / LightX2V

Commit e066bad2, authored Sep 02, 2025 by Yang Yong(雍洋), committed via GitHub on Sep 02, 2025
Parent: 066d7f19

Support vae 2d-grid dist infer & Rewrite FramePreprocessor using torch (#279)

Showing 3 changed files with 326 additions and 49 deletions:

- lightx2v/models/runners/wan/wan_audio_runner.py (+30, -30)
- lightx2v/models/schedulers/wan/audio/scheduler.py (+0, -3)
- lightx2v/models/video_encoders/hf/wan/vae.py (+296, -16)
lightx2v/models/runners/wan/wan_audio_runner.py

```diff
@@ -162,7 +162,7 @@ class AudioSegment:
     useful_length: Optional[int] = None


-class FramePreprocessor:
+class FramePreprocessorTorchVersion:
     """Handles frame preprocessing including noise and masking"""

     def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
@@ -170,40 +170,39 @@ class FramePreprocessor:
         self.noise_std = noise_std
         self.mask_rate = mask_rate

-    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+    def add_noise(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         """Add noise to frames"""
         if self.noise_mean is None or self.noise_std is None:
             return frames
-        if rnd_state is None:
-            rnd_state = np.random.RandomState()
+        device = frames.device
         shape = frames.shape
         bs = 1 if len(shape) == 4 else shape[0]
-        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
-        sigma = np.exp(sigma)
-        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
-        noise = rnd_state.randn(*shape) * sigma
+        # Generate sigma values on the same device
+        sigma = torch.normal(mean=self.noise_mean, std=self.noise_std, size=(bs,), device=device, generator=generator)
+        sigma = torch.exp(sigma)
+        for _ in range(1, len(shape)):
+            sigma = sigma.unsqueeze(-1)
+        # Generate noise on the same device
+        noise = torch.randn(*shape, device=device, generator=generator) * sigma
         return frames + noise

-    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+    def add_mask(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         """Add mask to frames"""
         if self.mask_rate is None:
             return frames
-        if rnd_state is None:
-            rnd_state = np.random.RandomState()
+        device = frames.device
         h, w = frames.shape[-2:]
-        mask = rnd_state.rand(h, w) > self.mask_rate
+        # Generate mask on the same device
+        mask = torch.rand(h, w, device=device, generator=generator) > self.mask_rate
         return frames * mask

     def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
         """Process previous frames with noise and masking"""
-        frames_np = frames.cpu().detach().numpy()
-        frames_np = self.add_noise(frames_np)
-        frames_np = self.add_mask(frames_np)
-        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
+        frames = self.add_noise(frames, torch.Generator(device=frames.device))
+        frames = self.add_mask(frames, torch.Generator(device=frames.device))
+        return frames


 class AudioProcessor:
@@ -283,8 +282,8 @@ class AudioProcessor:
 class WanAudioRunner(WanRunner):  # type:ignore
     def __init__(self, config):
         super().__init__(config)
-        self.frame_preprocessor = FramePreprocessor()
         self.prev_frame_length = self.config.get("prev_frame_length", 5)
+        self.frame_preprocessor = FramePreprocessorTorchVersion()

     def init_scheduler(self):
         """Initialize consistency model scheduler"""
@@ -399,14 +398,15 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.vae_encoder = self.load_vae_encoder()
         _, nframe, height, width = self.model.scheduler.latents.shape
-        if self.config.model_cls == "wan2.2_audio":
-            if prev_video is not None:
-                prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
-                prev_mask = self.model.scheduler.mask
-            else:
-                prev_latents = None
-        else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
-            frames_n = (nframe - 1) * 4 + 1
-            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        with ProfilingContext4Debug("vae_encoder in init run segment"):
+            if self.config.model_cls == "wan2.2_audio":
+                if prev_video is not None:
+                    prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
+                    prev_mask = self.model.scheduler.mask
+                else:
+                    prev_latents = None
+            else:
+                prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
+                frames_n = (nframe - 1) * 4 + 1
+                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
```
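The net effect of this file's changes: frame preprocessing now runs entirely in torch on the frames' own device, instead of round-tripping through NumPy on the CPU, and all randomness is funneled through an optional `torch.Generator`. A minimal usage sketch (the import path is an assumption based on this commit's file layout):

```python
# Minimal usage sketch; the import path is assumed from this commit's layout.
import torch
from lightx2v.models.runners.wan.wan_audio_runner import FramePreprocessorTorchVersion

device = "cuda" if torch.cuda.is_available() else "cpu"
pre = FramePreprocessorTorchVersion(noise_mean=-3.0, noise_std=0.5, mask_rate=0.1)

frames = torch.rand(1, 3, 5, 64, 64, device=device)  # [B, C, T, H, W] previous frames
out = pre.process_prev_frames(frames)

# No .cpu().numpy() round-trip anymore: everything stays on the input device.
assert out.shape == frames.shape and out.device == frames.device
```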
lightx2v/models/schedulers/wan/audio/scheduler.py

```diff
-import gc
 import math

 import numpy as np
@@ -99,8 +98,6 @@ class EulerScheduler(WanScheduler):
         self.prev_latents = previmg_encoder_output["prev_latents"]
         self.prev_len = previmg_encoder_output["prev_len"]
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
-        gc.collect()
-        torch.cuda.empty_cache()

     def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
         if in_tensor.ndim > tgt_n_dim:
```
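The scheduler change simply drops the per-call `gc.collect()` / `torch.cuda.empty_cache()` pair after `prepare_latents`. A rough, self-contained timing sketch of what those calls cost when repeated every segment (numbers are machine-dependent; without them, the caching allocator reuses freed blocks on its own):

```python
# Rough micro-benchmark of the cleanup calls the diff removes. Numbers vary by
# machine; the point is that both calls add latency per segment, while the
# caching allocator would reuse freed blocks anyway.
import gc
import time

import torch

def avg_ms(fn, n=50):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000 / n

print(f"gc.collect: {avg_ms(gc.collect):.3f} ms/call")
if torch.cuda.is_available():
    print(f"torch.cuda.empty_cache: {avg_ms(torch.cuda.empty_cache):.3f} ms/call")
```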
lightx2v/models/video_encoders/hf/wan/vae.py

```diff
@@ -801,12 +801,14 @@ class WanVAE:
         parallel=False,
         use_tiling=False,
         cpu_offload=False,
+        use_2d_split=True,
     ):
         self.dtype = dtype
         self.device = device
         self.parallel = parallel
         self.use_tiling = use_tiling
         self.cpu_offload = cpu_offload
+        self.use_2d_split = use_2d_split

         mean = [
             -0.7571,
```
```diff
@@ -848,9 +850,68 @@ class WanVAE:
         self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
         self.scale = [self.mean, self.inv_std]

+        # (height, width, world_size) -> (world_size_h, world_size_w)
+        self.grid_table = {
+            # world_size = 2
+            (60, 104, 2): (1, 2),
+            (68, 120, 2): (1, 2),
+            (90, 160, 2): (1, 2),
+            (60, 60, 2): (1, 2),
+            (72, 72, 2): (1, 2),
+            (88, 88, 2): (1, 2),
+            (120, 120, 2): (1, 2),
+            (104, 60, 2): (2, 1),
+            (120, 68, 2): (2, 1),
+            (160, 90, 2): (2, 1),
+            # world_size = 4
+            (60, 104, 4): (2, 2),
+            (68, 120, 4): (2, 2),
+            (90, 160, 4): (2, 2),
+            (60, 60, 4): (2, 2),
+            (72, 72, 4): (2, 2),
+            (88, 88, 4): (2, 2),
+            (120, 120, 4): (2, 2),
+            (104, 60, 4): (2, 2),
+            (120, 68, 4): (2, 2),
+            (160, 90, 4): (2, 2),
+            # world_size = 8
+            (60, 104, 8): (2, 4),
+            (68, 120, 8): (2, 4),
+            (90, 160, 8): (2, 4),
+            (60, 60, 8): (2, 4),
+            (72, 72, 8): (2, 4),
+            (88, 88, 8): (2, 4),
+            (120, 120, 8): (2, 4),
+            (104, 60, 8): (4, 2),
+            (120, 68, 8): (4, 2),
+            (160, 90, 8): (4, 2),
+        }
+
         # init model
         self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype).eval().requires_grad_(False).to(device).to(dtype)

+    def _calculate_2d_grid(self, latent_height, latent_width, world_size):
+        if (latent_height, latent_width, world_size) in self.grid_table:
+            best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
+            logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+            return best_h, best_w
+
+        best_h, best_w = 1, world_size
+        min_aspect_diff = float("inf")
+        for h in range(1, world_size + 1):
+            if world_size % h == 0:
+                w = world_size // h
+                if latent_height % h == 0 and latent_width % w == 0:
+                    # Calculate how close this grid is to square
+                    aspect_diff = abs((latent_height / h) - (latent_width / w))
+                    if aspect_diff < min_aspect_diff:
+                        min_aspect_diff = aspect_diff
+                        best_h, best_w = h, w
+        logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+        self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
+        return best_h, best_w
+
     def current_device(self):
         return next(self.model.parameters()).device
```
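For shapes missing from `grid_table`, `_calculate_2d_grid` searches the factor pairs of `world_size` that evenly divide the latent grid and keeps the pair whose per-rank tile is closest to square. A standalone sketch of that rule; the asserts mirror entries from the cached table above:

```python
# Standalone sketch of the fallback search in _calculate_2d_grid.
def calculate_2d_grid(latent_height: int, latent_width: int, world_size: int):
    best_h, best_w = 1, world_size
    min_aspect_diff = float("inf")
    for h in range(1, world_size + 1):
        if world_size % h:
            continue
        w = world_size // h
        if latent_height % h == 0 and latent_width % w == 0:
            # prefer the factorization whose per-rank tile is closest to square
            aspect_diff = abs(latent_height / h - latent_width / w)
            if aspect_diff < min_aspect_diff:
                min_aspect_diff = aspect_diff
                best_h, best_w = h, w
    return best_h, best_w

assert calculate_2d_grid(60, 104, 2) == (1, 2)   # matches grid_table above
assert calculate_2d_grid(104, 60, 2) == (2, 1)
assert calculate_2d_grid(90, 160, 4) == (2, 2)
assert calculate_2d_grid(160, 90, 8) == (4, 2)
```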
```diff
@@ -934,6 +995,97 @@ class WanVAE:
         return encoded.squeeze(0)

+    def encode_dist_2d(self, video, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
+        spatial_ratio = 8
+        # Calculate chunk sizes for both dimensions
+        total_latent_h = video.shape[3] // spatial_ratio
+        total_latent_w = video.shape[4] // spatial_ratio
+        chunk_h = total_latent_h // world_size_h
+        chunk_w = total_latent_w // world_size_w
+        padding_size = 1
+        video_chunk_h = chunk_h * spatial_ratio
+        video_chunk_w = chunk_w * spatial_ratio
+        video_padding_h = padding_size * spatial_ratio
+        video_padding_w = padding_size * spatial_ratio
+
+        # Calculate H dimension slice
+        if cur_rank_h == 0:
+            h_start = 0
+            h_end = video_chunk_h + 2 * video_padding_h
+        elif cur_rank_h == world_size_h - 1:
+            h_start = video.shape[3] - (video_chunk_h + 2 * video_padding_h)
+            h_end = video.shape[3]
+        else:
+            h_start = cur_rank_h * video_chunk_h - video_padding_h
+            h_end = (cur_rank_h + 1) * video_chunk_h + video_padding_h
+
+        # Calculate W dimension slice
+        if cur_rank_w == 0:
+            w_start = 0
+            w_end = video_chunk_w + 2 * video_padding_w
+        elif cur_rank_w == world_size_w - 1:
+            w_start = video.shape[4] - (video_chunk_w + 2 * video_padding_w)
+            w_end = video.shape[4]
+        else:
+            w_start = cur_rank_w * video_chunk_w - video_padding_w
+            w_end = (cur_rank_w + 1) * video_chunk_w + video_padding_w
+
+        # Extract the video chunk for this process
+        video_chunk = video[:, :, :, h_start:h_end, w_start:w_end].contiguous()
+
+        # Encode the chunk
+        if self.use_tiling:
+            encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
+        else:
+            encoded_chunk = self.model.encode(video_chunk, self.scale)
+
+        # Remove padding from encoded chunk
+        if cur_rank_h == 0:
+            encoded_h_start = 0
+            encoded_h_end = chunk_h
+        elif cur_rank_h == world_size_h - 1:
+            encoded_h_start = encoded_chunk.shape[3] - chunk_h
+            encoded_h_end = encoded_chunk.shape[3]
+        else:
+            encoded_h_start = padding_size
+            encoded_h_end = encoded_chunk.shape[3] - padding_size
+
+        if cur_rank_w == 0:
+            encoded_w_start = 0
+            encoded_w_end = chunk_w
+        elif cur_rank_w == world_size_w - 1:
+            encoded_w_start = encoded_chunk.shape[4] - chunk_w
+            encoded_w_end = encoded_chunk.shape[4]
+        else:
+            encoded_w_start = padding_size
+            encoded_w_end = encoded_chunk.shape[4] - padding_size
+
+        encoded_chunk = encoded_chunk[:, :, :, encoded_h_start:encoded_h_end, encoded_w_start:encoded_w_end].contiguous()
+
+        # Gather all chunks
+        total_processes = world_size_h * world_size_w
+        full_encoded = [torch.empty_like(encoded_chunk) for _ in range(total_processes)]
+        dist.all_gather(full_encoded, encoded_chunk)
+        torch.cuda.synchronize()
+
+        # Reconstruct the full encoded tensor
+        encoded_rows = []
+        for h_idx in range(world_size_h):
+            encoded_cols = []
+            for w_idx in range(world_size_w):
+                process_idx = h_idx * world_size_w + w_idx
+                encoded_cols.append(full_encoded[process_idx])
+            encoded_rows.append(torch.cat(encoded_cols, dim=4))
+        encoded = torch.cat(encoded_rows, dim=3)
+        return encoded.squeeze(0)
+
     def encode(self, video):
         """
         video: one video with shape [1, C, T, H, W].
```
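One property the slicing above quietly relies on: with the first/last-rank clamping, every rank's padded slice has identical length `chunk + 2 * padding`, which is what lets `dist.all_gather` (equal tensor shapes required) work without extra padding logic. A small sketch of the per-axis rule:

```python
# Per-axis halo slicing as used above: edge ranks take the halo from the
# inside, interior ranks take half from each neighbor; all slices are equal.
def slice_with_halo(total, chunk, pad, rank, ranks):
    if rank == 0:
        return 0, chunk + 2 * pad
    if rank == ranks - 1:
        return total - (chunk + 2 * pad), total
    return rank * chunk - pad, (rank + 1) * chunk + pad

total, ranks, pad = 480, 4, 8        # e.g. a 480-pixel axis, 4 ranks, 8-px halo
chunk = total // ranks
for r in range(ranks):
    s, e = slice_with_halo(total, chunk, pad, r, ranks)
    assert e - s == chunk + 2 * pad  # equal lengths -> all_gather-compatible
    print(f"rank {r}: [{s}:{e}]")
```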
```diff
@@ -946,17 +1098,23 @@ class WanVAE:
             cur_rank = dist.get_rank()
             height, width = video.shape[3], video.shape[4]
-            # Check if dimensions are divisible by world_size
-            if width % world_size == 0:
-                out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
-            elif height % world_size == 0:
-                out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
-            else:
-                logger.info("Fall back to naive encode mode")
-                if self.use_tiling:
-                    out = self.model.tiled_encode(video, self.scale).squeeze(0)
-                else:
-                    out = self.model.encode(video, self.scale).squeeze(0)
+            if self.use_2d_split:
+                world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size)
+                cur_rank_h = cur_rank // world_size_w
+                cur_rank_w = cur_rank % world_size_w
+                out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
+            else:
+                # Original 1D splitting logic
+                if width % world_size == 0:
+                    out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
+                elif height % world_size == 0:
+                    out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
+                else:
+                    logger.info("Fall back to naive encode mode")
+                    if self.use_tiling:
+                        out = self.model.tiled_encode(video, self.scale).squeeze(0)
+                    else:
+                        out = self.model.encode(video, self.scale).squeeze(0)
         else:
             if self.use_tiling:
                 out = self.model.tiled_encode(video, self.scale).squeeze(0)
```
```diff
@@ -1016,6 +1174,89 @@ class WanVAE:
         return images

+    def decode_dist_2d(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
+        total_h = zs.shape[2]
+        total_w = zs.shape[3]
+        chunk_h = total_h // world_size_h
+        chunk_w = total_w // world_size_w
+        padding_size = 1
+
+        # Calculate H dimension slice
+        if cur_rank_h == 0:
+            h_start = 0
+            h_end = chunk_h + 2 * padding_size
+        elif cur_rank_h == world_size_h - 1:
+            h_start = total_h - (chunk_h + 2 * padding_size)
+            h_end = total_h
+        else:
+            h_start = cur_rank_h * chunk_h - padding_size
+            h_end = (cur_rank_h + 1) * chunk_h + padding_size
+
+        # Calculate W dimension slice
+        if cur_rank_w == 0:
+            w_start = 0
+            w_end = chunk_w + 2 * padding_size
+        elif cur_rank_w == world_size_w - 1:
+            w_start = total_w - (chunk_w + 2 * padding_size)
+            w_end = total_w
+        else:
+            w_start = cur_rank_w * chunk_w - padding_size
+            w_end = (cur_rank_w + 1) * chunk_w + padding_size
+
+        # Extract the latent chunk for this process
+        zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()
+
+        # Decode the chunk
+        decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+        images_chunk = decode_func(zs_chunk.unsqueeze(0), self.scale).clamp_(-1, 1)
+
+        # Remove padding from decoded chunk
+        spatial_ratio = 8
+        if cur_rank_h == 0:
+            decoded_h_start = 0
+            decoded_h_end = chunk_h * spatial_ratio
+        elif cur_rank_h == world_size_h - 1:
+            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
+            decoded_h_end = images_chunk.shape[3]
+        else:
+            decoded_h_start = padding_size * spatial_ratio
+            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
+
+        if cur_rank_w == 0:
+            decoded_w_start = 0
+            decoded_w_end = chunk_w * spatial_ratio
+        elif cur_rank_w == world_size_w - 1:
+            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
+            decoded_w_end = images_chunk.shape[4]
+        else:
+            decoded_w_start = padding_size * spatial_ratio
+            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
+
+        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()
+
+        # Gather all chunks
+        total_processes = world_size_h * world_size_w
+        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
+        dist.all_gather(full_images, images_chunk)
+        torch.cuda.synchronize()
+
+        # Reconstruct the full image tensor
+        image_rows = []
+        for h_idx in range(world_size_h):
+            image_cols = []
+            for w_idx in range(world_size_w):
+                process_idx = h_idx * world_size_w + w_idx
+                image_cols.append(full_images[process_idx])
+            image_rows.append(torch.cat(image_cols, dim=4))
+        images = torch.cat(image_rows, dim=3)
+        return images
+
     def decode(self, zs):
         if self.cpu_offload:
             self.to_cuda()
```
```diff
@@ -1023,15 +1264,22 @@ class WanVAE:
         if self.parallel:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
-            height, width = zs.shape[2], zs.shape[3]
-            if width % world_size == 0:
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
-            elif height % world_size == 0:
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
-            else:
-                logger.info("Fall back to naive decode mode")
-                images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
+            latent_height, latent_width = zs.shape[2], zs.shape[3]
+            if self.use_2d_split:
+                world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size)
+                cur_rank_h = cur_rank // world_size_w
+                cur_rank_w = cur_rank % world_size_w
+                images = self.decode_dist_2d(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
+            else:
+                # Original 1D splitting logic
+                if latent_width % world_size == 0:
+                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
+                elif latent_height % world_size == 0:
+                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
+                else:
+                    logger.info("Fall back to naive decode mode")
+                    images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
         else:
             decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
             images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
```
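Both dispatchers map the flat process rank onto the 2D grid row-major; a quick enumeration makes the layout concrete (the 2x4 grid is chosen arbitrarily):

```python
# Row-major flat-rank -> (rank_h, rank_w) mapping used by encode() and decode().
world_size_h, world_size_w = 2, 4
for cur_rank in range(world_size_h * world_size_w):
    print(cur_rank, "->", (cur_rank // world_size_w, cur_rank % world_size_w))
# 0 -> (0, 0), 1 -> (0, 1), 2 -> (0, 2), 3 -> (0, 3), 4 -> (1, 0), ... 7 -> (1, 3)
```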
```diff
@@ -1041,3 +1289,35 @@ class WanVAE:
             self.to_cpu()

         return images
+
+
+if __name__ == "__main__":
+    dist.init_process_group(backend="nccl")
+    torch.cuda.set_device(dist.get_rank())
+
+    # # Test both 1D and 2D splitting
+    # print(f"Rank {dist.get_rank()}: Testing 1D splitting")
+    # model_1d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=False)
+    # model_1d.to_cuda()
+
+    input_tensor = torch.randn(1, 3, 17, 480, 480).to(torch.bfloat16).to("cuda")
+
+    # encoded_tensor_1d = model_1d.encode(input_tensor)
+    # print(f"rank {dist.get_rank()} 1D encoded_tensor shape: {encoded_tensor_1d.shape}")
+    # decoded_tensor_1d = model_1d.decode(encoded_tensor_1d)
+    # print(f"rank {dist.get_rank()} 1D decoded_tensor shape: {decoded_tensor_1d.shape}")
+
+    print(f"Rank {dist.get_rank()}: Testing 2D splitting")
+    model_2d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=True)
+    model_2d.to_cuda()
+
+    encoded_tensor_2d = model_2d.encode(input_tensor)
+    print(f"rank {dist.get_rank()} 2D encoded_tensor shape: {encoded_tensor_2d.shape}")
+    decoded_tensor_2d = model_2d.decode(encoded_tensor_2d)
+    print(f"rank {dist.get_rank()} 2D decoded_tensor shape: {decoded_tensor_2d.shape}")
+
+    # # Verify that both methods produce the same results
+    # if dist.get_rank() == 0:
+    #     print(f"Encoded tensors match: {torch.allclose(encoded_tensor_1d, encoded_tensor_2d, atol=1e-5)}")
+    #     print(f"Decoded tensors match: {torch.allclose(decoded_tensor_1d, decoded_tensor_2d, atol=1e-5)}")
+
+    dist.destroy_process_group()
```
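For a single-process sanity check of the halo/crop/concat bookkeeping (no NCCL needed), a toy stride-8 convolution can stand in for the VAE encoder: with kernel == stride, each latent depends on exactly one 8x8 patch, so the 2x2-grid reconstruction matches the full encode bit-for-bit. The real VAE has a much larger receptive field, which is why the halo exists at all; this sketch only validates the index arithmetic:

```python
# Toy stand-in for the VAE encoder (not the real WanVAE): kernel == stride == 8.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kernel = torch.randn(4, 3, 8, 8)

def toy_encode(x):
    # [1, 3, H, W] -> [1, 4, H // 8, W // 8]
    return F.conv2d(x, kernel, stride=8)

x = torch.randn(1, 3, 480, 480)
full = toy_encode(x)                 # [1, 4, 60, 60]

grid = 2                             # 2 x 2 rank grid
chunk = 60 // grid                   # latent tile per rank along each axis
pad = 1                              # 1-latent halo (8 pixels)
pc, pp = chunk * 8, pad * 8          # pixel-space tile and halo

def pixel_slice(rank):               # grid == 2: only first/last cases occur
    return (0, pc + 2 * pp) if rank == 0 else (480 - (pc + 2 * pp), 480)

def latent_crop(rank, out_len):
    return (0, chunk) if rank == 0 else (out_len - chunk, out_len)

rows = []
for rh in range(grid):
    hs, he = pixel_slice(rh)
    cols = []
    for rw in range(grid):
        ws, we = pixel_slice(rw)
        enc = toy_encode(x[:, :, hs:he, ws:we])
        chs, che = latent_crop(rh, enc.shape[2])
        cws, cwe = latent_crop(rw, enc.shape[3])
        cols.append(enc[:, :, chs:che, cws:cwe])
    rows.append(torch.cat(cols, dim=3))

assert torch.equal(torch.cat(rows, dim=2), full)
print("2x2 crop-and-concat matches the full encode")
```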