Commit e066bad2 authored by Yang Yong (雍洋), committed by GitHub

Support vae 2d-grid dist infer & Rewrite FramePreprocessor using torch (#279)

parent 066d7f19
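
As background for the VAE changes further down, here is a small standalone sketch (not part of the commit) of the row-major rank-to-grid bookkeeping the new 2D split relies on; rank_to_grid and grid_to_rank are illustrative names only:

def rank_to_grid(rank: int, world_size_w: int) -> tuple[int, int]:
    # matches cur_rank_h = rank // world_size_w, cur_rank_w = rank % world_size_w in encode()/decode()
    return rank // world_size_w, rank % world_size_w

def grid_to_rank(rank_h: int, rank_w: int, world_size_w: int) -> int:
    # matches process_idx = h_idx * world_size_w + w_idx used when re-assembling gathered chunks
    return rank_h * world_size_w + rank_w

# world_size = 8 arranged as a 2 x 4 grid:
print([rank_to_grid(r, world_size_w=4) for r in range(8)])
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]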
@@ -162,7 +162,7 @@ class AudioSegment:
     useful_length: Optional[int] = None


-class FramePreprocessor:
+class FramePreprocessorTorchVersion:
     """Handles frame preprocessing including noise and masking"""

     def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
@@ -170,40 +170,39 @@ class FramePreprocessor:
         self.noise_std = noise_std
         self.mask_rate = mask_rate

-    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+    def add_noise(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         """Add noise to frames"""
         if self.noise_mean is None or self.noise_std is None:
             return frames
-        if rnd_state is None:
-            rnd_state = np.random.RandomState()
+        device = frames.device
         shape = frames.shape
         bs = 1 if len(shape) == 4 else shape[0]
-        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
-        sigma = np.exp(sigma)
-        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
-        noise = rnd_state.randn(*shape) * sigma
+        # Generate sigma values on the same device
+        sigma = torch.normal(mean=self.noise_mean, std=self.noise_std, size=(bs,), device=device, generator=generator)
+        sigma = torch.exp(sigma)
+        for _ in range(1, len(shape)):
+            sigma = sigma.unsqueeze(-1)
+        # Generate noise on the same device
+        noise = torch.randn(*shape, device=device, generator=generator) * sigma
         return frames + noise

-    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+    def add_mask(self, frames: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         """Add mask to frames"""
         if self.mask_rate is None:
             return frames
-        if rnd_state is None:
-            rnd_state = np.random.RandomState()
+        device = frames.device
         h, w = frames.shape[-2:]
-        mask = rnd_state.rand(h, w) > self.mask_rate
+        # Generate mask on the same device
+        mask = torch.rand(h, w, device=device, generator=generator) > self.mask_rate
         return frames * mask

     def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
         """Process previous frames with noise and masking"""
-        frames_np = frames.cpu().detach().numpy()
-        frames_np = self.add_noise(frames_np)
-        frames_np = self.add_mask(frames_np)
-        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
+        frames = self.add_noise(frames, torch.Generator(device=frames.device))
+        frames = self.add_mask(frames, torch.Generator(device=frames.device))
+        return frames


 class AudioProcessor:
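
For a quick sanity check of the torch rewrite above, a standalone sketch (not part of the commit; it assumes FramePreprocessorTorchVersion is imported from the modified runner module, whose path is not shown in this excerpt):

import torch

pre = FramePreprocessorTorchVersion()
frames = torch.randn(2, 3, 5, 64, 64)  # (B, C, T, H, W); shape and values are illustrative only
out = pre.process_prev_frames(frames)
# Same shape, dtype and device as the input, with no CPU/numpy round-trip.
print(out.shape, out.dtype, out.device)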
@@ -283,8 +282,8 @@ class AudioProcessor:

 class WanAudioRunner(WanRunner):  # type:ignore
     def __init__(self, config):
         super().__init__(config)
-        self.frame_preprocessor = FramePreprocessor()
         self.prev_frame_length = self.config.get("prev_frame_length", 5)
+        self.frame_preprocessor = FramePreprocessorTorchVersion()

     def init_scheduler(self):
         """Initialize consistency model scheduler"""
@@ -399,14 +398,15 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.vae_encoder = self.load_vae_encoder()
         _, nframe, height, width = self.model.scheduler.latents.shape
-        if self.config.model_cls == "wan2.2_audio":
-            if prev_video is not None:
-                prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
-            else:
-                prev_latents = None
-            prev_mask = self.model.scheduler.mask
-        else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
-            frames_n = (nframe - 1) * 4 + 1
-            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        with ProfilingContext4Debug("vae_encoder in init run segment"):
+            if self.config.model_cls == "wan2.2_audio":
+                if prev_video is not None:
+                    prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
+                else:
+                    prev_latents = None
+                prev_mask = self.model.scheduler.mask
+            else:
+                prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
+                frames_n = (nframe - 1) * 4 + 1
+                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-import gc
 import math
 import numpy as np
@@ -99,8 +98,6 @@ class EulerScheduler(WanScheduler):
         self.prev_latents = previmg_encoder_output["prev_latents"]
         self.prev_len = previmg_encoder_output["prev_len"]
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
-        gc.collect()
-        torch.cuda.empty_cache()

     def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
         if in_tensor.ndim > tgt_n_dim:
@@ -801,12 +801,14 @@ class WanVAE:
         parallel=False,
         use_tiling=False,
         cpu_offload=False,
+        use_2d_split=True,
     ):
         self.dtype = dtype
         self.device = device
         self.parallel = parallel
         self.use_tiling = use_tiling
         self.cpu_offload = cpu_offload
+        self.use_2d_split = use_2d_split

         mean = [
             -0.7571,
@@ -848,9 +850,68 @@ class WanVAE:
         self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
         self.scale = [self.mean, self.inv_std]
# (height, width, world_size) -> (world_size_h, world_size_w)
self.grid_table = {
# world_size = 2
(60, 104, 2): (1, 2),
(68, 120, 2): (1, 2),
(90, 160, 2): (1, 2),
(60, 60, 2): (1, 2),
(72, 72, 2): (1, 2),
(88, 88, 2): (1, 2),
(120, 120, 2): (1, 2),
(104, 60, 2): (2, 1),
(120, 68, 2): (2, 1),
(160, 90, 2): (2, 1),
# world_size = 4
(60, 104, 4): (2, 2),
(68, 120, 4): (2, 2),
(90, 160, 4): (2, 2),
(60, 60, 4): (2, 2),
(72, 72, 4): (2, 2),
(88, 88, 4): (2, 2),
(120, 120, 4): (2, 2),
(104, 60, 4): (2, 2),
(120, 68, 4): (2, 2),
(160, 90, 4): (2, 2),
# world_size = 8
(60, 104, 8): (2, 4),
(68, 120, 8): (2, 4),
(90, 160, 8): (2, 4),
(60, 60, 8): (2, 4),
(72, 72, 8): (2, 4),
(88, 88, 8): (2, 4),
(120, 120, 8): (2, 4),
(104, 60, 8): (4, 2),
(120, 68, 8): (4, 2),
(160, 90, 8): (4, 2),
}
# init model
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype).eval().requires_grad_(False).to(device).to(dtype)
def _calculate_2d_grid(self, latent_height, latent_width, world_size):
if (latent_height, latent_width, world_size) in self.grid_table:
best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
return best_h, best_w
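# Fallback if no divisor grid splits both dims evenly: keep all ranks along the width (1 x world_size).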
best_h, best_w = 1, world_size
min_aspect_diff = float("inf")
for h in range(1, world_size + 1):
if world_size % h == 0:
w = world_size // h
if latent_height % h == 0 and latent_width % w == 0:
# Calculate how close this grid is to square
aspect_diff = abs((latent_height / h) - (latent_width / w))
if aspect_diff < min_aspect_diff:
min_aspect_diff = aspect_diff
best_h, best_w = h, w
logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
return best_h, best_w
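
A minimal standalone check of the grid search above (same logic as _calculate_2d_grid, minus the cache and logging):

def best_grid(latent_h: int, latent_w: int, world_size: int) -> tuple[int, int]:
    best, min_diff = (1, world_size), float("inf")
    for h in range(1, world_size + 1):
        if world_size % h:
            continue
        w = world_size // h
        if latent_h % h == 0 and latent_w % w == 0:
            diff = abs(latent_h / h - latent_w / w)
            if diff < min_diff:
                min_diff, best = diff, (h, w)
    return best

print(best_grid(90, 160, 8))  # (2, 4): 45x40 tiles are closer to square than the 1x8 split's 90x20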
def current_device(self):
    return next(self.model.parameters()).device

@@ -934,6 +995,97 @@ class WanVAE:
return encoded.squeeze(0)
def encode_dist_2d(self, video, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
spatial_ratio = 8
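# The VAE reduces H and W by 8x from pixels to latents, so one latent pixel of padding corresponds to 8 video pixels.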
# Calculate chunk sizes for both dimensions
total_latent_h = video.shape[3] // spatial_ratio
total_latent_w = video.shape[4] // spatial_ratio
chunk_h = total_latent_h // world_size_h
chunk_w = total_latent_w // world_size_w
padding_size = 1
video_chunk_h = chunk_h * spatial_ratio
video_chunk_w = chunk_w * spatial_ratio
video_padding_h = padding_size * spatial_ratio
video_padding_w = padding_size * spatial_ratio
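# Every rank encodes its chunk plus a one-latent-pixel halo on each side; the first and last ranks shift their window inward instead of shrinking it, so all encoded chunks have identical shapes for the all_gather below.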
# Calculate H dimension slice
if cur_rank_h == 0:
h_start = 0
h_end = video_chunk_h + 2 * video_padding_h
elif cur_rank_h == world_size_h - 1:
h_start = video.shape[3] - (video_chunk_h + 2 * video_padding_h)
h_end = video.shape[3]
else:
h_start = cur_rank_h * video_chunk_h - video_padding_h
h_end = (cur_rank_h + 1) * video_chunk_h + video_padding_h
# Calculate W dimension slice
if cur_rank_w == 0:
w_start = 0
w_end = video_chunk_w + 2 * video_padding_w
elif cur_rank_w == world_size_w - 1:
w_start = video.shape[4] - (video_chunk_w + 2 * video_padding_w)
w_end = video.shape[4]
else:
w_start = cur_rank_w * video_chunk_w - video_padding_w
w_end = (cur_rank_w + 1) * video_chunk_w + video_padding_w
# Extract the video chunk for this process
video_chunk = video[:, :, :, h_start:h_end, w_start:w_end].contiguous()
# Encode the chunk
if self.use_tiling:
encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
else:
encoded_chunk = self.model.encode(video_chunk, self.scale)
# Remove padding from encoded chunk
if cur_rank_h == 0:
encoded_h_start = 0
encoded_h_end = chunk_h
elif cur_rank_h == world_size_h - 1:
encoded_h_start = encoded_chunk.shape[3] - chunk_h
encoded_h_end = encoded_chunk.shape[3]
else:
encoded_h_start = padding_size
encoded_h_end = encoded_chunk.shape[3] - padding_size
if cur_rank_w == 0:
encoded_w_start = 0
encoded_w_end = chunk_w
elif cur_rank_w == world_size_w - 1:
encoded_w_start = encoded_chunk.shape[4] - chunk_w
encoded_w_end = encoded_chunk.shape[4]
else:
encoded_w_start = padding_size
encoded_w_end = encoded_chunk.shape[4] - padding_size
encoded_chunk = encoded_chunk[:, :, :, encoded_h_start:encoded_h_end, encoded_w_start:encoded_w_end].contiguous()
# Gather all chunks
total_processes = world_size_h * world_size_w
full_encoded = [torch.empty_like(encoded_chunk) for _ in range(total_processes)]
dist.all_gather(full_encoded, encoded_chunk)
torch.cuda.synchronize()
# Reconstruct the full encoded tensor
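# all_gather returns chunks in rank order; ranks are laid out row-major (rank = h_idx * world_size_w + w_idx), matching the cur_rank_h / cur_rank_w mapping in encode().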
encoded_rows = []
for h_idx in range(world_size_h):
encoded_cols = []
for w_idx in range(world_size_w):
process_idx = h_idx * world_size_w + w_idx
encoded_cols.append(full_encoded[process_idx])
encoded_rows.append(torch.cat(encoded_cols, dim=4))
encoded = torch.cat(encoded_rows, dim=3)
return encoded.squeeze(0)
     def encode(self, video):
         """
         video: one video with shape [1, C, T, H, W].
@@ -946,17 +1098,23 @@ class WanVAE:
             cur_rank = dist.get_rank()
             height, width = video.shape[3], video.shape[4]
-            # Check if dimensions are divisible by world_size
-            if width % world_size == 0:
-                out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
-            elif height % world_size == 0:
-                out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
-            else:
-                logger.info("Fall back to naive encode mode")
-                if self.use_tiling:
-                    out = self.model.tiled_encode(video, self.scale).squeeze(0)
-                else:
-                    out = self.model.encode(video, self.scale).squeeze(0)
+            if self.use_2d_split:
+                world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size)
+                cur_rank_h = cur_rank // world_size_w
+                cur_rank_w = cur_rank % world_size_w
+                out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
+            else:
+                # Original 1D splitting logic
+                if width % world_size == 0:
+                    out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
+                elif height % world_size == 0:
+                    out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
+                else:
+                    logger.info("Fall back to naive encode mode")
+                    if self.use_tiling:
+                        out = self.model.tiled_encode(video, self.scale).squeeze(0)
+                    else:
+                        out = self.model.encode(video, self.scale).squeeze(0)
         else:
             if self.use_tiling:
                 out = self.model.tiled_encode(video, self.scale).squeeze(0)
@@ -1016,6 +1174,89 @@ class WanVAE:
return images
def decode_dist_2d(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
total_h = zs.shape[2]
total_w = zs.shape[3]
chunk_h = total_h // world_size_h
chunk_w = total_w // world_size_w
padding_size = 1
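# Same halo scheme as encode_dist_2d, but applied directly in latent space: one latent pixel of overlap per side, with edge ranks shifting their window inward so every chunk keeps the same size.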
# Calculate H dimension slice
if cur_rank_h == 0:
h_start = 0
h_end = chunk_h + 2 * padding_size
elif cur_rank_h == world_size_h - 1:
h_start = total_h - (chunk_h + 2 * padding_size)
h_end = total_h
else:
h_start = cur_rank_h * chunk_h - padding_size
h_end = (cur_rank_h + 1) * chunk_h + padding_size
# Calculate W dimension slice
if cur_rank_w == 0:
w_start = 0
w_end = chunk_w + 2 * padding_size
elif cur_rank_w == world_size_w - 1:
w_start = total_w - (chunk_w + 2 * padding_size)
w_end = total_w
else:
w_start = cur_rank_w * chunk_w - padding_size
w_end = (cur_rank_w + 1) * chunk_w + padding_size
# Extract the latent chunk for this process
zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()
# Decode the chunk
decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
images_chunk = decode_func(zs_chunk.unsqueeze(0), self.scale).clamp_(-1, 1)
# Remove padding from decoded chunk
spatial_ratio = 8
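# Decoded frames are 8x the latent resolution, so the halo to crop is padding_size * 8 pixels per side.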
if cur_rank_h == 0:
decoded_h_start = 0
decoded_h_end = chunk_h * spatial_ratio
elif cur_rank_h == world_size_h - 1:
decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
decoded_h_end = images_chunk.shape[3]
else:
decoded_h_start = padding_size * spatial_ratio
decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
if cur_rank_w == 0:
decoded_w_start = 0
decoded_w_end = chunk_w * spatial_ratio
elif cur_rank_w == world_size_w - 1:
decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
decoded_w_end = images_chunk.shape[4]
else:
decoded_w_start = padding_size * spatial_ratio
decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()
# Gather all chunks
total_processes = world_size_h * world_size_w
full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
dist.all_gather(full_images, images_chunk)
torch.cuda.synchronize()
# Reconstruct the full image tensor
image_rows = []
for h_idx in range(world_size_h):
image_cols = []
for w_idx in range(world_size_w):
process_idx = h_idx * world_size_w + w_idx
image_cols.append(full_images[process_idx])
image_rows.append(torch.cat(image_cols, dim=4))
images = torch.cat(image_rows, dim=3)
return images
     def decode(self, zs):
         if self.cpu_offload:
             self.to_cuda()
@@ -1023,15 +1264,22 @@ class WanVAE:
         if self.parallel:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
-            height, width = zs.shape[2], zs.shape[3]
-            if width % world_size == 0:
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
-            elif height % world_size == 0:
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
-            else:
-                logger.info("Fall back to naive decode mode")
-                images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
+            latent_height, latent_width = zs.shape[2], zs.shape[3]
+            if self.use_2d_split:
+                world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size)
+                cur_rank_h = cur_rank // world_size_w
+                cur_rank_w = cur_rank % world_size_w
+                images = self.decode_dist_2d(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
+            else:
+                # Original 1D splitting logic
+                if latent_width % world_size == 0:
+                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
+                elif latent_height % world_size == 0:
+                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
+                else:
+                    logger.info("Fall back to naive decode mode")
+                    images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
         else:
             decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
             images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
@@ -1041,3 +1289,35 @@ class WanVAE:
         self.to_cpu()
         return images
if __name__ == "__main__":
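# Assumes launch via torchrun (e.g. torchrun --nproc_per_node=2 <this_file>.py), which sets RANK / WORLD_SIZE / MASTER_ADDR for init_process_group.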
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
# # Test both 1D and 2D splitting
# print(f"Rank {dist.get_rank()}: Testing 1D splitting")
# model_1d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=False)
# model_1d.to_cuda()
input_tensor = torch.randn(1, 3, 17, 480, 480).to(torch.bfloat16).to("cuda")
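# With 480x480 input and the default z_dim of 16 for this VAE, the gathered latent should come out as (16, 5, 60, 60): (17 - 1) / 4 + 1 = 5 latent frames, 480 / 8 = 60 latent pixels per side.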
# encoded_tensor_1d = model_1d.encode(input_tensor)
# print(f"rank {dist.get_rank()} 1D encoded_tensor shape: {encoded_tensor_1d.shape}")
# decoded_tensor_1d = model_1d.decode(encoded_tensor_1d)
# print(f"rank {dist.get_rank()} 1D decoded_tensor shape: {decoded_tensor_1d.shape}")
print(f"Rank {dist.get_rank()}: Testing 2D splitting")
model_2d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=True)
model_2d.to_cuda()
encoded_tensor_2d = model_2d.encode(input_tensor)
print(f"rank {dist.get_rank()} 2D encoded_tensor shape: {encoded_tensor_2d.shape}")
decoded_tensor_2d = model_2d.decode(encoded_tensor_2d)
print(f"rank {dist.get_rank()} 2D decoded_tensor shape: {decoded_tensor_2d.shape}")
# # Verify that both methods produce the same results
# if dist.get_rank() == 0:
# print(f"Encoded tensors match: {torch.allclose(encoded_tensor_1d, encoded_tensor_2d, atol=1e-5)}")
# print(f"Decoded tensors match: {torch.allclose(decoded_tensor_1d, decoded_tensor_2d, atol=1e-5)}")
dist.destroy_process_group()