Commit 048be946 authored by gaclove

feat: add adaptive resizing configuration and implement new resizing functions in WanAudioRunner

parent d048b178
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 16,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 16,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 16,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "adaptive_resize": true
}
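
The three configs above differ only in resolution, and the third additionally sets "adaptive_resize": true, the new flag this commit reads in WanAudioRunner (via config.get("adaptive_resize", False), see below). A minimal sketch of loading one of these files, assuming plain JSON on disk; the load_config helper and file name here are hypothetical:

import json

def load_config(path):
    # Hypothetical helper for illustration; the runner's real config object
    # may be richer than a plain dict.
    with open(path) as f:
        return json.load(f)

config = load_config("wan_audio_720p_adaptive.json")  # hypothetical file name
if config.get("adaptive_resize", False):
    print("using bucket-based adaptive resizing")
else:
    print(f"fixed target {config['target_height']}x{config['target_width']}")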
@@ -51,6 +51,90 @@ def memory_efficient_inference():
    gc.collect()

def optimize_latent_size_with_sp(lat_h, lat_w, sp_size, patch_size):
    # Round the latent dimensions so the patch-token count can be split
    # across sp_size sequence-parallel ranks.
    patched_h, patched_w = lat_h // patch_size[0], lat_w // patch_size[1]
    if (patched_h * patched_w) % sp_size == 0:
        return lat_h, lat_w
    else:
        # Distribute factors of 2 from sp_size between the two spatial dims,
        # preferring whichever dimension halves evenly; otherwise fall back
        # to the dimension that has been force-divided fewer times.
        h_ratio, w_ratio = 1, 1
        h_noevenly_n, w_noevenly_n = 0, 0
        h_backup, w_backup = patched_h, patched_w
        while sp_size // 2 != 1:
            if h_backup % 2 == 0:
                h_backup //= 2
                h_ratio *= 2
            elif w_backup % 2 == 0:
                w_backup //= 2
                w_ratio *= 2
            elif h_noevenly_n <= w_noevenly_n:
                h_backup //= 2
                h_ratio *= 2
                h_noevenly_n += 1
            else:
                w_backup //= 2
                w_ratio *= 2
                w_noevenly_n += 1
            sp_size //= 2
        # Snap each latent dimension down to a multiple of its accumulated ratio
        new_lat_h = lat_h // h_ratio * h_ratio
        new_lat_w = lat_w // w_ratio * w_ratio
        return new_lat_h, new_lat_w
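
A quick sanity check with illustrative numbers (not from the commit): a 480x832 frame at VAE stride 8 gives a 60x104 latent; with patch_size (2, 2) that is 30x52 = 1560 patch tokens, already divisible by a sequence-parallel size of 4, so the dimensions come back unchanged:

lat_h, lat_w = optimize_latent_size_with_sp(60, 104, sp_size=4, patch_size=(2, 2))
assert (lat_h, lat_w) == (60, 104)  # 30 * 52 = 1560 tokens, 1560 % 4 == 0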

def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
    # Compute a center-crop bounding box matching the target aspect ratio.
    tgt_ar = tgt_h / tgt_w
    ori_ar = ori_h / ori_w
    if abs(ori_ar - tgt_ar) < 0.01:
        return 0, ori_h, 0, ori_w
    if ori_ar > tgt_ar:
        # Source is taller than the target: crop the height
        crop_h = int(tgt_ar * ori_w)
        y0 = (ori_h - crop_h) // 2
        y1 = y0 + crop_h
        return y0, y1, 0, ori_w
    else:
        # Source is wider than the target: crop the width
        crop_w = int(ori_h / tgt_ar)
        x0 = (ori_w - crop_w) // 2
        x1 = x0 + crop_w
        return 0, ori_h, x0, x1
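
A worked example with illustrative sizes: a 4:3 source (1080x1440) cropped for a 16:9 target (720x1280) keeps the full width and center-crops the height:

y0, y1, x0, x1 = get_crop_bbox(1080, 1440, 720, 1280)
# crop_h = int((720 / 1280) * 1440) = 810, centered: y0 = 135, y1 = 945
assert (y0, y1, x0, x1) == (135, 945, 0, 1440)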

def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
    """
    frames: (T, C, H, W)
    size: (H, W)
    """
    ori_h, ori_w = frames.shape[2:]
    h, w = size
    y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
    cropped_frames = frames[:, :, y0:y1, x0:x1]
    resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
    return resized_frames
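
A shape check under the imports this hunk presumably relies on (resize and InterpolationMode from torchvision.transforms.functional; the diff does not show the import block, so these are assumptions):

import torch
from torchvision.transforms.functional import InterpolationMode, resize  # assumed imports

frames = torch.rand(8, 3, 1080, 1440)  # (T, C, H, W), illustrative size
out = isotropic_crop_resize(frames, (720, 1280))
assert out.shape == (8, 3, 720, 1280)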

def adaptive_resize(img):
    # Resolution buckets per aspect-ratio class (the second element of each
    # tuple looks like sampling probabilities; it is unused in this function)
    bucket_config = {
        0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
        1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
        1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
    }
    ori_height = img.shape[-2]
    ori_width = img.shape[-1]
    ori_ratio = ori_height / ori_width
    aspect_ratios = np.array(list(bucket_config.keys()))
    closest_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
    closest_ratio = aspect_ratios[closest_aspect_idx]
    # Fallback targets if the image is smaller than every bucket
    if ori_ratio < 1.0:
        target_h, target_w = 480, 832
    elif ori_ratio == 1.0:
        target_h, target_w = 480, 480
    else:
        target_h, target_w = 832, 480
    # Pick the largest bucket whose pixel count the source image can cover
    for resolution in bucket_config[closest_ratio][0]:
        if ori_height * ori_width >= resolution[0] * resolution[1]:
            target_h, target_w = resolution
    cropped_img = isotropic_crop_resize(img, (target_h, target_w))
    return cropped_img, target_h, target_w
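
Tracing the bucket selection with an illustrative 600x800 image: the ratio 0.75 maps to the 0.667 bucket list, and 600 * 800 = 480000 pixels only covers the 480x832 bucket (399360 pixels), so that resolution wins:

img = torch.rand(1, 3, 600, 800)  # illustrative reference image
resized, tgt_h, tgt_w = adaptive_resize(img)
assert (tgt_h, tgt_w) == (480, 832)
assert resized.shape == (1, 3, 480, 832)
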
@dataclass
class AudioSegment:
"""Data class for audio segment information"""
@@ -300,61 +384,6 @@ class VideoGenerator:
        return gen_video
@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
    def __init__(self, config):

@@ -604,19 +633,42 @@ class WanAudioRunner(WanRunner):
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]
        if config.get("adaptive_resize", False):
            # Adaptive path: crop/resize the reference image into the closest
            # resolution bucket for its aspect ratio
            cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
            config.tgt_h = tgt_h
            config.tgt_w = tgt_w
            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
            lat_h, lat_w = tgt_h // 8, tgt_w // 8
            config.lat_h = lat_h
            config.lat_w = lat_w
            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
            if isinstance(vae_encode_out, list):
                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
        else:
            # Fixed-budget path: keep the reference aspect ratio but fit it to
            # the configured target area, snapped to VAE stride and patch size
            h, w = ref_img.shape[2:]
            aspect_ratio = h / w
            max_area = config.target_height * config.target_width
            lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
            lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
            lat_h, lat_w = optimize_latent_size_with_sp(lat_h, lat_w, 1, config.patch_size[1:])
            config.lat_h, config.lat_w = lat_h, lat_w
            config.tgt_h = lat_h * config.vae_stride[1]
            config.tgt_w = lat_w * config.vae_stride[2]
            # Resize image to target size
            cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
            # Prepare for VAE encoding
            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
            if isinstance(vae_encode_out, list):
                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
        return vae_encode_out, clip_encoder_out
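
For the non-adaptive path, the arithmetic with the 480x832 config above and an illustrative 720x1280 reference image (assuming Wan2.1's usual vae_stride = (4, 8, 8) and patch_size = (1, 2, 2), which the diff does not show) lands on 464x832 rather than exactly 480x832, because the reference aspect ratio is preserved under the pixel budget:

import numpy as np  # assumed available, as in the hunk above

aspect_ratio = 720 / 1280                                      # 0.5625
max_area = 480 * 832                                           # 399360 pixel budget
lat_h = round(np.sqrt(max_area * aspect_ratio) // 8 // 2 * 2)  # 58
lat_w = round(np.sqrt(max_area / aspect_ratio) // 8 // 2 * 2)  # 104
tgt_h, tgt_w = lat_h * 8, lat_w * 8                            # 464, 832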