Commit 8c024081 authored by gaclove


refactor: optimize adaptive resizing logic in WanAudioRunner for improved aspect ratio handling and latent size calculation
parent 87bbed1c
...
@@ -625,42 +625,48 @@ class WanAudioRunner(WanRunner):
         ref_img = rearrange(ref_img, "H W C -> 1 C H W")
         ref_img = ref_img[:, :3]
-        if config.get("adaptive_resize", False):
+        adaptive = config.get("adaptive_resize", False)
+        if adaptive:
             # Use adaptive_resize to modify aspect ratio
-            cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
-            config.tgt_h = tgt_h
-            config.tgt_w = tgt_w
-            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
-            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-            lat_h, lat_w = tgt_h // 8, tgt_w // 8
-            config.lat_h = lat_h
-            config.lat_w = lat_w
-            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-            if isinstance(vae_encode_out, list):
-                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+            ref_img, h, w = adaptive_resize(ref_img)
+
+            patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
+            patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
+            patched_h, patched_w = optimize_latent_size_with_sp(patched_h, patched_w, 1, self.config.patch_size[1:])
+
+            config.lat_h = patched_h * self.config.patch_size[1]
+            config.lat_w = patched_w * self.config.patch_size[2]
+
+            config.tgt_h = config.lat_h * self.config.vae_stride[1]
+            config.tgt_w = config.lat_w * self.config.vae_stride[2]
         else:
             h, w = ref_img.shape[2:]
             aspect_ratio = h / w
             max_area = config.target_height * config.target_width
-            lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
-            lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
-            lat_h, lat_w = optimize_latent_size_with_sp(lat_h, lat_w, 1, config.patch_size[1:])
-            config.lat_h, config.lat_w = lat_h, lat_w
-            config.tgt_h = lat_h * config.vae_stride[1]
-            config.tgt_w = lat_w * config.vae_stride[2]
-            # Resize image to target size
-            cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
-            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
-            # Prepare for VAE encoding
-            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-            if isinstance(vae_encode_out, list):
-                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+
+            patched_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1])
+            patched_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2])
+            patched_h, patched_w = optimize_latent_size_with_sp(patched_h, patched_w, 1, config.patch_size[1:])
+            config.lat_h = patched_h * config.patch_size[1]
+            config.lat_w = patched_w * config.patch_size[2]
+            config.tgt_h = config.lat_h * config.vae_stride[1]
+            config.tgt_w = config.lat_w * config.vae_stride[2]
+
+        logger.info(f"[wan_audio] adaptive_resize: {adaptive}, tgt_h: {config.tgt_h}, tgt_w: {config.tgt_w}, lat_h: {config.lat_h}, lat_w: {config.lat_w}")
+
+        clip_encoder_out = self.image_encoder.visual([ref_img], self.config).squeeze(0).to(torch.bfloat16)
+
+        cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
+        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
+        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
+        if isinstance(vae_encode_out, list):
+            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
         return vae_encode_out, clip_encoder_out
...
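Note on the refactored sizing math: both branches now derive the patch-grid dimensions first (pixels // vae_stride // patch_size), align them via optimize_latent_size_with_sp, and only then multiply back out, so config.lat_h/config.lat_w come out as exact multiples of patch_size and config.tgt_h/config.tgt_w as exact multiples of vae_stride * patch_size, unlike the old adaptive branch's hard-coded // 8. Below is a minimal sketch of the non-adaptive branch's arithmetic; optimize_latent_size_with_sp is stubbed as an identity (its implementation is outside this diff), and vae_stride=(4, 8, 8), patch_size=(1, 2, 2) are assumed example values, not read from the repo config.

import numpy as np

# Assumed example values for the (t, h, w) strides -- not taken from this diff.
vae_stride = (4, 8, 8)
patch_size = (1, 2, 2)

def optimize_latent_size_with_sp(patched_h, patched_w, patched_t, patch_hw):
    # Stub: the real helper adjusts the patch grid for sequence parallelism.
    return patched_h, patched_w

target_height, target_width = 480, 832   # stands in for config.target_height/width
h, w = 720, 1280                         # reference image size
aspect_ratio = h / w
max_area = target_height * target_width

# Work in patch-grid units first ...
patched_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1])
patched_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2])
patched_h, patched_w = optimize_latent_size_with_sp(patched_h, patched_w, 1, patch_size[1:])

# ... then multiply back: lat dims are multiples of patch_size,
# tgt dims are multiples of vae_stride * patch_size.
lat_h, lat_w = patched_h * patch_size[1], patched_w * patch_size[2]
tgt_h, tgt_w = lat_h * vae_stride[1], lat_w * vae_stride[2]

print(lat_h, lat_w, tgt_h, tgt_w)        # 58 104 464 832

With these inputs a 720x1280 reference maps to a 464x832 target, i.e. a 58x104 latent that tiles evenly into 2x2 patches.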