"vscode:/vscode.git/clone" did not exist on "6387051411db46f255f8025fa24af9b464cfdc2b"
Unverified Commit f2a2657b authored by PengGao's avatar PengGao Committed by GitHub
Browse files

Add target_fps parameter to video task requests and update inference worker (#528)

parent 4c0a9a0d
...@@ -52,6 +52,7 @@ async def create_video_task_form( ...@@ -52,6 +52,7 @@ async def create_video_task_form(
seed: int = Form(default=42), seed: int = Form(default=42),
audio_file: UploadFile = File(None), audio_file: UploadFile = File(None),
video_duration: int = Form(default=5), video_duration: int = Form(default=5),
target_fps: int = Form(default=16),
): ):
services = get_services() services = get_services()
assert services.file_service is not None, "File service is not initialized" assert services.file_service is not None, "File service is not initialized"
...@@ -89,6 +90,7 @@ async def create_video_task_form( ...@@ -89,6 +90,7 @@ async def create_video_task_form(
seed=seed, seed=seed,
audio_path=audio_path, audio_path=audio_path,
video_duration=video_duration, video_duration=video_duration,
target_fps=target_fps,
) )
try: try:
......
...@@ -40,6 +40,7 @@ class VideoTaskRequest(BaseTaskRequest): ...@@ -40,6 +40,7 @@ class VideoTaskRequest(BaseTaskRequest):
audio_path: str = Field("", description="Input audio path (Wan-Audio)") audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)") video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)") talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)")
class ImageTaskRequest(BaseTaskRequest): class ImageTaskRequest(BaseTaskRequest):
...@@ -53,6 +54,7 @@ class TaskRequest(BaseTaskRequest): ...@@ -53,6 +54,7 @@ class TaskRequest(BaseTaskRequest):
video_duration: int = Field(5, description="Video duration (Wan-Audio)") video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)") talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)") aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)")
target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)")
class TaskStatusMessage(BaseModel): class TaskStatusMessage(BaseModel):
......
...@@ -55,6 +55,11 @@ class TorchrunInferenceWorker: ...@@ -55,6 +55,11 @@ class TorchrunInferenceWorker:
task_data["return_result_tensor"] = False task_data["return_result_tensor"] = False
task_data["negative_prompt"] = task_data.get("negative_prompt", "") task_data["negative_prompt"] = task_data.get("negative_prompt", "")
if task_data.get("target_fps") is not None and "video_frame_interpolation" in self.runner.config:
task_data["video_frame_interpolation"] = dict(self.runner.config["video_frame_interpolation"])
task_data["video_frame_interpolation"]["target_fps"] = task_data["target_fps"]
del task_data["target_fps"]
task_data = EasyDict(task_data) task_data = EasyDict(task_data)
input_info = set_input_info(task_data) input_info = set_input_info(task_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment