schema.py 1.51 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
from pydantic import BaseModel, Field

PengGao's avatar
PengGao committed
3
4
5
6
7
8
9
10
from ..utils.generate_task_id import generate_task_id


class TaskRequest(BaseModel):
    # Request payload describing one generation task. Every field has a default;
    # when save_result_path is left empty, __init__ derives it from task_id.
    # (Documented with comments rather than a docstring so the model's JSON
    # schema output is unchanged.)

    # --- Identity and prompting ---
    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field(default="", description="Generation prompt")
    use_prompt_enhancer: bool = Field(default=False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field(default="", description="Negative prompt")

    # --- Inputs and outputs ---
    # NOTE: despite the field name, the description says this may carry base64
    # image data or a URL, not only a filesystem path.
    image_path: str = Field(default="", description="Base64 encoded image or URL")
    num_fragments: int = Field(default=1, description="Number of fragments")
    # Empty string means "not provided"; filled in by __init__.
    save_result_path: str = Field(default="", description="Save video path (optional, defaults to task_id.mp4)")

    # --- Generation parameters ---
    infer_steps: int = Field(default=5, description="Inference steps")
    target_video_length: int = Field(default=81, description="Target video length")
    seed: int = Field(default=42, description="Random seed")

    # --- Wan-Audio specific (per the field descriptions) ---
    audio_path: str = Field(default="", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(default=5, description="Video duration (Wan-Audio)")

    def __init__(self, **data):
        """Validate ``data`` via pydantic, then default the output path.

        If no truthy ``save_result_path`` was supplied, it is set to
        ``"<task_id>.mp4"``. ``task_id`` is either the supplied value or the one
        produced by ``generate_task_id``.
        """
        super().__init__(**data)
        if self.save_result_path:
            # Caller chose an explicit output path; keep it.
            return
        self.save_result_path = f"{self.task_id}.mp4"

    def get(self, key, default=None):
        """Dict-style accessor: return attribute ``key``, or ``default`` if it is missing.

        This resolves any attribute on the instance via ``getattr``, not only
        declared fields. For example, a model method name returns the bound
        method rather than ``default``.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default


class TaskStatusMessage(BaseModel):
    # Carries a single task identifier; the Ellipsis default makes it required.
    task_id: str = Field(default=..., description="Task ID")


class TaskResponse(BaseModel):
    # Response describing a task. All three fields are required (no defaults).
    task_id: str
    task_status: str  # status reported as free-form text
    save_result_path: str  # output path associated with the task
PengGao's avatar
PengGao committed
37
38
39
40
41


class StopTaskResponse(BaseModel):
    # Result of a stop request. Both fields are required free-form strings.
    stop_status: str
    reason: str  # human-readable explanation accompanying stop_status