api_server.py 1.49 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
import json

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner


class Message(BaseModel):
    """JSON body accepted by the POST /v1/local/video/generate endpoint.

    `prompt` and `save_video_path` must be supplied by the client;
    `negative_prompt` and `image_path` fall back to empty strings.
    """

    prompt: str
    negative_prompt: str = ""
    image_path: str = ""
    save_video_path: str

    def get(self, key, default=None):
        """Read field `key` mapping-style, returning `default` if it is absent.

        This lets a Message stand in where a dict is expected. Presumably
        `runner.set_inputs` reads its inputs this way (not visible here).
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default


async def main(message):
    """Feed one request into the shared runner and run the pipeline.

    `runner` is the module-level object created in the `__main__` block at
    startup. Neither call below is awaited, so the whole pipeline executes
    on the event-loop thread. Requests are therefore handled strictly one at
    a time, and the server cannot respond to anything else until the run
    finishes.

    Returns a fixed status payload once `run_pipeline()` has returned.
    """
    pipeline = runner
    pipeline.set_inputs(message)
    pipeline.run_pipeline()
    status = {"response": "finished"}
    return status


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # `--model_cls` is required, so argparse can never fall back to a default
    # for it; the former `default="hunyuan"` was dead and has been dropped.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    print(f"args: {args}")

    # Build the config and the runner once at startup. The same runner is
    # reused by main() for every request. This setup is timed under the
    # "Init Server Cost" profiling label.
    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = init_runner(config)

    app = FastAPI()

    @app.post("/v1/local/video/generate")
    async def generate_video(message: Message):
        """Handle a generation request by delegating to main()."""
        response = await main(message)
        return response

    # Listen on the port given on the command line. Reading it from `args`
    # (rather than `config`) keeps `--port` authoritative, whatever
    # set_config or the JSON config file do with a `port` key.
    uvicorn.run(app, host="0.0.0.0", port=args.port)