schema.py 1.5 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
from pydantic import BaseModel, Field

PengGao's avatar
PengGao committed
3
4
5
6
7
8
9
10
from ..utils.generate_task_id import generate_task_id


class TaskRequest(BaseModel):
    """Parameters for a single generation task.

    Every field has a default, so a request only needs to supply the values it
    wants to override. When ``save_video_path`` is left empty it is filled in as
    ``"<task_id>.mp4"`` after validation.
    """

    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field("", description="Generation prompt")
    use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field("", description="Negative prompt")
    image_path: str = Field("", description="Base64 encoded image or URL")
    num_fragments: int = Field(1, description="Number of fragments")
    save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
    infer_steps: int = Field(5, description="Inference steps")
    target_video_length: int = Field(81, description="Target video length")
    seed: int = Field(42, description="Random seed")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")

    def __init__(self, **data) -> None:
        """Validate ``data``; derive ``save_video_path`` from ``task_id`` when it is empty.

        An explicitly passed empty string is treated the same as omitting the field.
        """
        super().__init__(**data)
        # Runs after base validation so that a task_id produced by the
        # default_factory is already set on the instance.
        if not self.save_video_path:
            self.save_video_path = f"{self.task_id}.mp4"

    def get(self, key: str, default: object = None) -> object:
        """Return attribute ``key``, or ``default`` if it does not exist (``dict.get``-style).

        Lookup goes through ``getattr``, so any attribute of the instance
        (including methods inherited from ``BaseModel``) can be returned, not only
        the declared fields.
        """
        return getattr(self, key, default)


class TaskStatusMessage(BaseModel):
    # Message carrying a single required task identifier
    # (presumably used to look up a task's status — confirm against callers).
    # `default=...` marks the field as required, exactly like a positional `...`.
    task_id: str = Field(default=..., description="Task ID")


class TaskResponse(BaseModel):
    # Response model with three required string fields: the task id, its status
    # text, and the path the video is saved to. `Field(...)` is the explicit
    # spelling of "required, no default" and is equivalent to a bare annotation.
    task_id: str = Field(...)
    task_status: str = Field(...)
    save_video_path: str = Field(...)


class StopTaskResponse(BaseModel):
    # Response model for a stop request: a status string plus a free-form reason.
    # Both fields are required; `Field(...)` is equivalent to a bare annotation.
    stop_status: str = Field(...)
    reason: str = Field(...)