Unverified Commit aa627f77 authored by Musisoul's avatar Musisoul Committed by GitHub
Browse files

Fix vae parallel bug (#548)

parent a1a1a8c0
...@@ -271,6 +271,58 @@ class WanRunner(DefaultRunner): ...@@ -271,6 +271,58 @@ class WanRunner(DefaultRunner):
gc.collect() gc.collect()
return clip_encoder_out return clip_encoder_out
def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size):
"""
Adjust latent dimensions for optimal 2D grid splitting.
Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
"""
world_size_h, world_size_w = 1, 1
if world_size <= 1:
return latent_h, latent_w, world_size_h, world_size_w
# Define priority grids for different world sizes
priority_grids = []
if world_size == 8:
# For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
elif world_size == 4:
priority_grids = [(2, 2), (1, 4), (4, 1)]
elif world_size == 2:
priority_grids = [(1, 2), (2, 1)]
else:
# For other sizes, try factor pairs
for h in range(1, int(np.sqrt(world_size)) + 1):
if world_size % h == 0:
w = world_size // h
priority_grids.append((h, w))
# Try priority grids first
for world_size_h, world_size_w in priority_grids:
if latent_h % world_size_h == 0 and latent_w % world_size_w == 0:
return latent_h, latent_w, world_size_h, world_size_w
# If no perfect fit, find minimal padding solution
best_grid = (1, world_size) # fallback
min_total_padding = float("inf")
for world_size_h, world_size_w in priority_grids:
# Calculate required padding
pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
total_padding = pad_h + pad_w
# Prefer grids with minimal total padding
if total_padding < min_total_padding:
min_total_padding = total_padding
best_grid = (world_size_h, world_size_w)
# Apply padding
world_size_h, world_size_w = best_grid
pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w
@ProfilingContext4DebugL1( @ProfilingContext4DebugL1(
"Run VAE Encoder", "Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(), recorder_mode=GET_RECORDER_MODE(),
...@@ -281,8 +333,19 @@ class WanRunner(DefaultRunner): ...@@ -281,8 +333,19 @@ class WanRunner(DefaultRunner):
h, w = first_frame.shape[2:] h, w = first_frame.shape[2:]
aspect_ratio = h / w aspect_ratio = h / w
max_area = self.config["target_height"] * self.config["target_width"] max_area = self.config["target_height"] * self.config["target_width"]
latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2]) # Calculate initial latent dimensions
ori_latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
ori_latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])
# Adjust latent dimensions for optimal 2D grid splitting when using distributed processing
if dist.is_initialized() and dist.get_world_size() > 1:
latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size())
logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}")
else:
latent_h, latent_w = ori_latent_h, ori_latent_w
world_size_h, world_size_w = None, None
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) # Important: latent_shape is used to set the input_info latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) # Important: latent_shape is used to set the input_info
if self.config.get("changing_resolution", False): if self.config.get("changing_resolution", False):
...@@ -293,8 +356,8 @@ class WanRunner(DefaultRunner): ...@@ -293,8 +356,8 @@ class WanRunner(DefaultRunner):
int(latent_h * self.config["resolution_rate"][i]) // 2 * 2, int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
int(latent_w * self.config["resolution_rate"][i]) // 2 * 2, int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
) )
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp)) vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp, world_size_h=world_size_h, world_size_w=world_size_w))
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w)) vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w, world_size_h=world_size_h, world_size_w=world_size_w))
return vae_encode_out_list, latent_shape return vae_encode_out_list, latent_shape
else: else:
if last_frame is not None: if last_frame is not None:
...@@ -307,10 +370,10 @@ class WanRunner(DefaultRunner): ...@@ -307,10 +370,10 @@ class WanRunner(DefaultRunner):
round(last_frame_size[1] * last_frame_resize_ratio), round(last_frame_size[1] * last_frame_resize_ratio),
] ]
last_frame = TF.center_crop(last_frame, last_frame_size) last_frame = TF.center_crop(last_frame, last_frame_size)
vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame) vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame, world_size_h=world_size_h, world_size_w=world_size_w)
return vae_encoder_out, latent_shape return vae_encoder_out, latent_shape
def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None): def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None, world_size_h=None, world_size_w=None):
h = lat_h * self.config["vae_stride"][1] h = lat_h * self.config["vae_stride"][1]
w = lat_w * self.config["vae_stride"][2] w = lat_w * self.config["vae_stride"][2]
msk = torch.ones( msk = torch.ones(
...@@ -350,7 +413,7 @@ class WanRunner(DefaultRunner): ...@@ -350,7 +413,7 @@ class WanRunner(DefaultRunner):
dim=1, dim=1,
).to(AI_DEVICE) ).to(AI_DEVICE)
vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE())) vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()), world_size_h=world_size_h, world_size_w=world_size_w)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder del self.vae_encoder
......
...@@ -1119,7 +1119,7 @@ class WanVAE: ...@@ -1119,7 +1119,7 @@ class WanVAE:
return encoded.squeeze(0) return encoded.squeeze(0)
def encode(self, video): def encode(self, video, world_size_h=None, world_size_w=None):
""" """
video: one video with shape [1, C, T, H, W]. video: one video with shape [1, C, T, H, W].
""" """
...@@ -1132,7 +1132,8 @@ class WanVAE: ...@@ -1132,7 +1132,8 @@ class WanVAE:
height, width = video.shape[3], video.shape[4] height, width = video.shape[3], video.shape[4]
if self.use_2d_split: if self.use_2d_split:
world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size) if world_size_h is None or world_size_w is None:
world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size)
cur_rank_h = cur_rank // world_size_w cur_rank_h = cur_rank // world_size_w
cur_rank_w = cur_rank % world_size_w cur_rank_w = cur_rank % world_size_w
out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w) out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment