# schema.py
from typing import Optional

from pydantic import BaseModel, Field

from ..utils.generate_task_id import generate_task_id


class TalkObject(BaseModel):
    """A talk object, described by the path to its audio and the path to its mask."""

    # Both fields are required: `default=...` (Ellipsis) means no default value.
    audio: str = Field(default=..., description="Audio path")
    mask: str = Field(default=..., description="Mask path")


class TaskRequest(BaseModel):
    """Request parameters for a single generation task.

    Every field is optional. ``task_id`` is produced by ``generate_task_id``
    when omitted, and an empty ``save_result_path`` is replaced by
    ``"<task_id>.mp4"`` after validation.
    """

    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field("", description="Generation prompt")
    use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field("", description="Negative prompt")
    image_path: str = Field("", description="Base64 encoded image or URL")
    num_fragments: int = Field(1, description="Number of fragments")
    save_result_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
    infer_steps: int = Field(5, description="Inference steps")
    target_video_length: int = Field(81, description="Target video length")
    seed: int = Field(42, description="Random seed")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")

    def __init__(self, **data):
        """Validate ``data``, then fill in a default ``save_result_path``.

        Args:
            **data: Field values, passed straight to ``BaseModel.__init__``.
        """
        super().__init__(**data)
        # Both the default "" and an explicitly passed "" count as "not set";
        # derive the output path from the (possibly auto-generated) task id.
        if not self.save_result_path:
            self.save_result_path = f"{self.task_id}.mp4"

    def get(self, key: str, default=None):
        """Dict-style lookup: return ``getattr(self, key, default)``.

        Args:
            key: Attribute name to read.
            default: Value returned when the attribute does not exist.

        Note: this is plain ``getattr``, so any attribute (including methods)
        is reachable, not only the declared fields.
        """
        return getattr(self, key, default)


class TaskStatusMessage(BaseModel):
    """Message that identifies a task by its ID."""

    # Required: `default=...` (Ellipsis) means the caller must supply it.
    task_id: str = Field(default=..., description="Task ID")


class TaskResponse(BaseModel):
    """Response carrying a task's ID, its status string and its result path."""

    # All three fields are required (no defaults).
    task_id: str
    task_status: str
    save_result_path: str


class StopTaskResponse(BaseModel):
    """Response model holding a stop status string and an accompanying reason."""

    # Both fields are required (no defaults).
    stop_status: str
    reason: str