"vscode:/vscode.git/clone" did not exist on "a2b9a52367401e5039d8c572f0fc23bb7295a538"
Unverified Commit 7be16e43 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

api support resize mode (#545)

parent d4e82934
...@@ -169,19 +169,29 @@ def resize_image(img, resize_mode="adaptive", bucket_shape=None, fixed_area=None ...@@ -169,19 +169,29 @@ def resize_image(img, resize_mode="adaptive", bucket_shape=None, fixed_area=None
if ori_height * ori_weight >= resolution[0] * resolution[1]: if ori_height * ori_weight >= resolution[0] * resolution[1]:
target_h, target_w = resolution target_h, target_w = resolution
elif resize_mode == "keep_ratio_fixed_area": elif resize_mode == "keep_ratio_fixed_area":
assert fixed_area in ["480p", "720p"], f"fixed_area must be in ['480p', '720p'], but got {fixed_area}, please set fixed_area in config." area_in_pixels = 480 * 832
fixed_area = 480 * 832 if fixed_area == "480p" else 720 * 1280 if fixed_area == "480p":
target_h = round(np.sqrt(fixed_area * ori_ratio)) area_in_pixels = 480 * 832
target_w = round(np.sqrt(fixed_area / ori_ratio)) elif fixed_area == "720p":
area_in_pixels = 720 * 1280
else:
area_in_pixels = 480 * 832
target_h = round(np.sqrt(area_in_pixels * ori_ratio))
target_w = round(np.sqrt(area_in_pixels / ori_ratio))
elif resize_mode == "fixed_min_area": elif resize_mode == "fixed_min_area":
aspect_ratios = np.array(np.array(list(bucket_config.keys()))) aspect_ratios = np.array(np.array(list(bucket_config.keys())))
closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio)) closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
closet_ratio = aspect_ratios[closet_aspect_idx] closet_ratio = aspect_ratios[closet_aspect_idx]
target_h, target_w = bucket_config[closet_ratio][0] target_h, target_w = bucket_config[closet_ratio][0]
elif resize_mode == "fixed_min_side": elif resize_mode == "fixed_min_side":
assert fixed_area in ["480p", "720p"], f"fixed_min_side mode requires fixed_area to be '480p' or '720p', got {fixed_area}" min_side = 720
if fixed_area == "720p":
min_side = 720 if fixed_area == "720p" else 480 min_side = 720
elif fixed_area == "480p":
min_side = 480
else:
logger.warning(f"[wan_audio] fixed_area is not '480p' or '720p', using default 480p: {fixed_area}")
min_side = 480
if ori_ratio < 1.0: if ori_ratio < 1.0:
target_h = min_side target_h = min_side
target_w = round(target_h / ori_ratio) target_w = round(target_h / ori_ratio)
...@@ -195,6 +205,7 @@ def resize_image(img, resize_mode="adaptive", bucket_shape=None, fixed_area=None ...@@ -195,6 +205,7 @@ def resize_image(img, resize_mode="adaptive", bucket_shape=None, fixed_area=None
target_h, target_w = bucket_config[closet_ratio][-1] target_h, target_w = bucket_config[closet_ratio][-1]
cropped_img = isotropic_crop_resize(img, (target_h, target_w)) cropped_img = isotropic_crop_resize(img, (target_h, target_w))
logger.info(f"[wan_audio] resize_image: {img.shape} -> {cropped_img.shape}, resize_mode: {resize_mode}, target_h: {target_h}, target_w: {target_w}")
return cropped_img, target_h, target_w return cropped_img, target_h, target_w
......
...@@ -41,6 +41,7 @@ class VideoTaskRequest(BaseTaskRequest): ...@@ -41,6 +41,7 @@ class VideoTaskRequest(BaseTaskRequest):
video_duration: int = Field(5, description="Video duration (Wan-Audio)") video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)") talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)") target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)")
resize_mode: Optional[str] = Field("adaptive", description="Resize mode (adaptive, keep_ratio_fixed_area, fixed_min_area, fixed_max_area, fixed_shape, fixed_min_side)")
class ImageTaskRequest(BaseTaskRequest): class ImageTaskRequest(BaseTaskRequest):
......
...@@ -118,6 +118,7 @@ class BaseGenerationService(ABC): ...@@ -118,6 +118,7 @@ class BaseGenerationService(ABC):
self._prepare_output_path(message.save_result_path, task_data) self._prepare_output_path(message.save_result_path, task_data)
task_data["seed"] = message.seed task_data["seed"] = message.seed
task_data["resize_mode"] = message.resize_mode
result = await self.inference_service.submit_task_async(task_data) result = await self.inference_service.submit_task_async(task_data)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment