import random
from typing import Any, Optional

from pydantic import BaseModel, Field

from ..utils.generate_task_id import generate_task_id


def generate_random_seed() -> int:
    """Return a random integer seed in the inclusive range ``[0, 2**32 - 1]``.

    The value is drawn from the module-level ``random`` generator, so the
    sequence is reproducible if ``random.seed`` was called beforehand.
    """
    return random.randint(0, 2**32 - 1)


class TalkObject(BaseModel):
    """Audio/mask path pair for one entry of a request's ``talk_objects`` list."""

    audio: str = Field(..., description="Audio path")
    mask: str = Field(..., description="Mask path")


class BaseTaskRequest(BaseModel):
    """Fields shared by every generation task request.

    ``task_id`` and ``seed`` come from default factories when the caller
    omits them. ``save_result_path`` falls back to ``task_id`` when it is left
    empty (see ``__init__``).
    """

    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field("", description="Generation prompt")
    use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field("", description="Negative prompt")
    image_path: str = Field("", description="Base64 encoded image or URL")
    save_result_path: str = Field("", description="Save result path (optional, defaults to task_id, suffix auto-detected)")
    infer_steps: int = Field(5, description="Inference steps")
    seed: int = Field(default_factory=generate_random_seed, description="Random seed (auto-generated if not set)")

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        # An empty path (the field default) means "not provided": fall back
        # to the task ID, which is always populated by this point.
        if not self.save_result_path:
            self.save_result_path = self.task_id

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-style accessor: return attribute ``key``, or ``default`` if absent.

        Presumably this lets code written against plain dicts read request
        fields too. It uses ``getattr``, so any attribute resolves, not only
        declared model fields.
        """
        return getattr(self, key, default)


class VideoTaskRequest(BaseTaskRequest):
    """Request for a video generation task.

    Adds video-specific settings on top of ``BaseTaskRequest``. Fields whose
    descriptions are tagged "(Wan-Audio)" belong to the Wan-Audio model.
    """

    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
    target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)")


class ImageTaskRequest(BaseTaskRequest):
    """Request for an image generation task; adds the output aspect ratio."""

    aspect_ratio: str = Field(default="16:9", description="Output aspect ratio")


class TaskRequest(BaseTaskRequest):
    """Combined request carrying both video and image (T2I) settings.

    This class repeats the fields of ``VideoTaskRequest`` and
    ``ImageTaskRequest`` instead of inheriting from them, so
    ``isinstance`` checks against those classes stay False. Keep the three
    definitions in sync when editing fields.
    """

    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length (video only)")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
    aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)")
    target_fps: Optional[int] = Field(16, description="Target FPS for video frame interpolation (overrides config)")


class TaskStatusMessage(BaseModel):
    """Message that identifies a single task by its ID."""

    task_id: str = Field(..., description="Task ID")


class TaskResponse(BaseModel):
    """Response describing a task's status and the path its result is saved to."""

    task_id: str
    task_status: str
    save_result_path: str


class StopTaskResponse(BaseModel):
    """Result of a stop-task call: a status string plus the reason for it."""

    stop_status: str  # outcome of the stop attempt (values set by the caller)
    reason: str  # explanation accompanying stop_status