schema.py 2.61 KB
Newer Older
PengGao's avatar
PengGao committed
1
import random
2
3
from typing import Optional

PengGao's avatar
PengGao committed
4
5
from pydantic import BaseModel, Field

PengGao's avatar
PengGao committed
6
7
8
from ..utils.generate_task_id import generate_task_id


PengGao's avatar
PengGao committed
9
10
11
12
def generate_random_seed() -> int:
    """Return a seed drawn uniformly from the unsigned 32-bit range [0, 2**32)."""
    # randrange(n) draws from the same distribution/RNG path as randint(0, n - 1).
    upper_exclusive = 1 << 32
    return random.randrange(upper_exclusive)


13
14
15
16
17
class TalkObject(BaseModel):
    # Pairs an audio path with a mask path; both fields are required.
    # (Comments rather than a docstring: pydantic would copy a docstring
    # into the generated JSON schema.)
    audio: str = Field(default=..., description="Audio path")
    mask: str = Field(default=..., description="Mask path")


PengGao's avatar
PengGao committed
18
class BaseTaskRequest(BaseModel):
    # Fields shared by every task request. Declaration order is significant:
    # pydantic uses it for schema and serialization order, so keep it stable.
    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field(default="", description="Generation prompt")
    use_prompt_enhancer: bool = Field(default=False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field(default="", description="Negative prompt")
    image_path: str = Field(default="", description="Base64 encoded image or URL")
    save_result_path: str = Field(
        default="",
        description="Save result path (optional, defaults to task_id, suffix auto-detected)",
    )
    infer_steps: int = Field(default=5, description="Inference steps")
    # A new seed is drawn per instance unless the caller supplies one.
    seed: int = Field(default_factory=generate_random_seed, description="Random seed (auto-generated if not set)")

    def __init__(self, **data):
        super().__init__(**data)
        # Keep a caller-supplied path. An empty path (default or explicit "")
        # falls back to the task id.
        if self.save_result_path:
            return
        self.save_result_path = f"{self.task_id}"

    def get(self, key, default=None):
        # dict.get-style access by attribute name. Any attribute (including
        # methods) can be returned; only a missing attribute yields `default`.
        try:
            return getattr(self, key)
        except AttributeError:
            return default


PengGao's avatar
PengGao committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class VideoTaskRequest(BaseTaskRequest):
    # Base request fields plus video-generation options; the audio-related
    # fields are marked "(Wan-Audio)" in their descriptions.
    num_fragments: int = Field(default=1, description="Number of fragments")
    target_video_length: int = Field(default=81, description="Target video length")
    audio_path: str = Field(default="", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(default=5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(default=None, description="Talk objects (Wan-Audio)")


class ImageTaskRequest(BaseTaskRequest):
    # Base request fields plus the single image-specific option.
    aspect_ratio: str = Field(default="16:9", description="Output aspect ratio")


class TaskRequest(BaseTaskRequest):
    # Combined request carrying both the video fields and the image field.
    # The fields are redeclared rather than inherited from VideoTaskRequest /
    # ImageTaskRequest, so instances are not isinstance() of either class.
    # Some descriptions differ from those classes ("video only", "T2I only").
    num_fragments: int = Field(default=1, description="Number of fragments")
    target_video_length: int = Field(default=81, description="Target video length (video only)")
    audio_path: str = Field(default="", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(default=5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(default=None, description="Talk objects (Wan-Audio)")
    aspect_ratio: str = Field(default="16:9", description="Output aspect ratio (T2I only)")


PengGao's avatar
PengGao committed
58
59
60
61
62
63
64
class TaskStatusMessage(BaseModel):
    # Carries only the id of the task being referred to; the id is required.
    task_id: str = Field(default=..., description="Task ID")


class TaskResponse(BaseModel):
    # Response returned for a task: its id, a status string, and the result
    # path. All three fields are required strings.
    task_id: str
    task_status: str
    save_result_path: str
PengGao's avatar
PengGao committed
66
67
68
69
70


class StopTaskResponse(BaseModel):
    # Outcome of a stop request: a status string plus a human-readable
    # reason. Both fields are required.
    stop_status: str
    reason: str