Commit e75d0db7 authored by gaclove

refactor: replace optimize_latent_size_with_sp with get_optimal_patched_size_with_sp in WanAudioRunner for improved clarity and functionality
parent 40381d5a
@@ -40,33 +40,26 @@ def memory_efficient_inference():
     gc.collect()
 
 
-def optimize_latent_size_with_sp(lat_h, lat_w, sp_size, patch_size):
-    patched_h, patched_w = lat_h // patch_size[0], lat_w // patch_size[1]
-    if (patched_h * patched_w) % sp_size == 0:
-        return lat_h, lat_w
-    else:
-        h_ratio, w_ratio = 1, 1
-        h_noevenly_n, w_noevenly_n = 0, 0
-        h_backup, w_backup = patched_h, patched_w
-        while sp_size // 2 != 1:
-            if h_backup % 2 == 0:
-                h_backup //= 2
-                h_ratio *= 2
-            elif w_backup % 2 == 0:
-                w_backup //= 2
-                w_ratio *= 2
-            elif h_noevenly_n <= w_noevenly_n:
-                h_backup //= 2
-                h_ratio *= 2
-                h_noevenly_n += 1
-            else:
-                w_backup //= 2
-                w_ratio *= 2
-                w_noevenly_n += 1
-            sp_size //= 2
-        new_lat_h = lat_h // h_ratio * h_ratio
-        new_lat_w = lat_w // w_ratio * w_ratio
-        return new_lat_h, new_lat_w
+def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
+    assert sp_size > 0 and (sp_size & (sp_size - 1)) == 0, "sp_size must be a power of 2"
+    h_ratio, w_ratio = 1, 1
+    while sp_size != 1:
+        sp_size //= 2
+        if patched_h % 2 == 0:
+            patched_h //= 2
+            h_ratio *= 2
+        elif patched_w % 2 == 0:
+            patched_w //= 2
+            w_ratio *= 2
+        else:
+            if patched_h > patched_w:
+                patched_h //= 2
+                h_ratio *= 2
+            else:
+                patched_w //= 2
+                w_ratio *= 2
+    return patched_h * h_ratio, patched_w * w_ratio
 
 
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
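
For context on the new helper's behavior: each halving of sp_size strips one factor of 2 from whichever patched dimension is still even, falling back to the larger dimension when both are odd, so the returned grid always divides evenly across sequence-parallel ranks. A minimal check, assuming the function above is in scope (illustrative, not part of the commit):

# Grid already splits across sp_size=4 ranks: returned unchanged.
assert get_optimal_patched_size_with_sp(30, 52, 4) == (30, 52)

# Odd 45x53 grid: the wider side is halved twice (53 -> 26 -> 13), then
# scaled back up, rounding the grid to 45x52 (area 2340, divisible by 4).
assert get_optimal_patched_size_with_sp(45, 53, 4) == (45, 52)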
@@ -95,7 +88,7 @@ def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
     h, w = size
     y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
     cropped_frames = frames[:, :, y0:y1, x0:x1]
-    resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
+    resized_frames = resize(cropped_frames, [h, w], InterpolationMode.BICUBIC, antialias=True)
     return resized_frames
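
The change from the `size` tuple to an explicit `[h, w]` list is likely about matching torchvision's annotation for `resize`, which documents `size` as a list of ints; the list spelling is the safer one across versions. A standalone sketch of the call, with illustrative shapes:

import torch
from torchvision.transforms.functional import InterpolationMode, resize

frames = torch.rand(2, 3, 480, 832)  # (N, C, H, W) video frames
h, w = 256, 448                      # hypothetical target size
out = resize(frames, [h, w], InterpolationMode.BICUBIC, antialias=True)
assert out.shape == (2, 3, 256, 448)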
@@ -389,7 +382,7 @@ class VideoGenerator:
 
 
 @RUNNER_REGISTER("wan2.1_audio")
-class WanAudioRunner(WanRunner):
+class WanAudioRunner(WanRunner):  # type:ignore
     def __init__(self, config):
         super().__init__(config)
         self._audio_adapter_pipe = None
@@ -461,6 +454,9 @@ class WanAudioRunner(WanRunner):
         # Ensure models are initialized
         self.initialize()
 
+        assert self._audio_processor is not None
+        assert self._audio_preprocess is not None
+
         # Initialize video generator if needed
         if self._video_generator is None:
             self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
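
The two added asserts read as type narrowing for the checker: `_audio_processor` and `_audio_preprocess` are presumably `Optional` until `initialize()` runs, and asserting non-None lets later attribute access pass without the `# type: ignore` this commit removes further down. A generic sketch of the pattern (class and attribute names here are illustrative, not from this repo):

from typing import Optional

class AudioProcessor:
    audio_sr: int = 16000

class Runner:
    def __init__(self) -> None:
        self._audio_processor: Optional[AudioProcessor] = None

    def initialize(self) -> None:
        self._audio_processor = AudioProcessor()

    def run(self) -> int:
        self.initialize()
        # Narrows Optional[AudioProcessor] to AudioProcessor for the checker.
        assert self._audio_processor is not None
        return self._audio_processor.audio_sr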
@@ -579,13 +575,9 @@ class WanAudioRunner(WanRunner):
             audio_path = audio_tmp.name
 
             try:
                 # Save video
                 save_to_video(images, video_path, fps)
-
-                # Save audio
-                ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)  # type: ignore
+                ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)
 
                 # Merge video and audio
                 output_path = self.config.get("save_video_path")
                 parent_dir = os.path.dirname(output_path)
                 if parent_dir and not os.path.exists(parent_dir):
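
For reference, `torchaudio.save` writes a 2-D `(channels, frames)` tensor, which is why the mono `audio_array` gets a leading axis via `[None]` before saving. A self-contained sketch; the path and sample rate are placeholders:

import numpy as np
import torch
import torchaudio as ta

audio_array = np.zeros(16000, dtype=np.float32)  # 1 s of mono audio
wav = torch.tensor(audio_array[None])            # shape (1, 16000): (channels, frames)
ta.save("demo.wav", wav, sample_rate=16000)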
@@ -639,14 +631,6 @@ class WanAudioRunner(WanRunner):
-            patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
-            patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
-            patched_h, patched_w = optimize_latent_size_with_sp(patched_h, patched_w, 1, self.config.patch_size[1:])
-            config.lat_h = patched_h * self.config.patch_size[1]
-            config.lat_w = patched_w * self.config.patch_size[2]
-            config.tgt_h = config.lat_h * self.config.vae_stride[1]
-            config.tgt_w = config.lat_w * self.config.vae_stride[2]
-
         else:
             h, w = ref_img.shape[2:]
             aspect_ratio = h / w
@@ -655,19 +639,22 @@ class WanAudioRunner(WanRunner):
             patched_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1])
             patched_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2])
-            patched_h, patched_w = optimize_latent_size_with_sp(patched_h, patched_w, 1, config.patch_size[1:])
-            config.lat_h = patched_h * config.patch_size[1]
-            config.lat_w = patched_w * config.patch_size[2]
-            config.tgt_h = config.lat_h * config.vae_stride[1]
-            config.tgt_w = config.lat_w * config.vae_stride[2]
+            patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
+            config.lat_h = patched_h * self.config.patch_size[1]
+            config.lat_w = patched_w * self.config.patch_size[2]
+            config.tgt_h = config.lat_h * self.config.vae_stride[1]
+            config.tgt_w = config.lat_w * self.config.vae_stride[2]
 
         logger.info(f"[wan_audio] adaptive_resize: {adaptive}, tgt_h: {config.tgt_h}, tgt_w: {config.tgt_w}, lat_h: {config.lat_h}, lat_w: {config.lat_w}")
 
-        clip_encoder_out = self.image_encoder.visual([ref_img], self.config).squeeze(0).to(torch.bfloat16)
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
+
+        # clip encoder
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
+
+        # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
        if isinstance(vae_encode_out, list):
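
Two things happen in this last hunk: the CLIP encoder now consumes the resized `cond_frms` rather than the raw `ref_img`, so CLIP and the VAE see the same conditioning frames, and the size bookkeeping switches to the shared helper. The arithmetic chains patched grid -> latent dims (times patch_size) -> pixel dims (times vae_stride). The worked example below assumes (1, 2, 2) / (4, 8, 8) as typical Wan values; they are not constants read from this repo's config:

import numpy as np

# Assumed typical Wan defaults, for illustration only.
patch_size = (1, 2, 2)
vae_stride = (4, 8, 8)
max_area = 480 * 832
aspect_ratio = 720 / 1280  # from a hypothetical 720x1280 reference image

patched_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1])  # 29
patched_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2])  # 52
lat_h, lat_w = patched_h * patch_size[1], patched_w * patch_size[2]                    # 58, 104
tgt_h, tgt_w = lat_h * vae_stride[1], lat_w * vae_stride[2]                            # 464, 832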