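"""FastAPI service (text_encoder.py) that exposes the lightx2v text encoders over HTTP.

The runner selected via --model_cls is instantiated once at startup and only its
text encoders are loaded; each request submits a prompt (plus an optional image
and negative prompt) and receives the encoder output packaged for transport.
"""
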
import argparse
import json
import os
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel

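# Importing the runner modules registers their classes in RUNNER_REGISTER,
# which is queried below to resolve model_cls to a concrete runner.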
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================

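# Global TextEncoderRunner, created in the __main__ block once the config is loaded.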
runner = None

app = FastAPI()


class Message(BaseModel):
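    """Request payload for a text-encoding task."""
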
    task_id: str
    task_id_must_unique: bool = False

    text: str
    img: Optional[bytes] = None
    n_prompt: Optional[str] = None

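    # dict-style accessor so callers can read fields with message.get("key", default)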
    def get(self, key, default=None):
        return getattr(self, key, default)


class TextEncoderServiceStatus(BaseServiceStatus):
    pass


class TextEncoderRunner:
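    """Thin wrapper around a registered runner that loads and serves only its text encoders."""
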
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]

        self.runner = self.runner_cls(config)
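        # Load only the text encoders; this service does not need the rest of the runner's pipeline.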
        self.runner.text_encoders = self.runner.load_text_encoder()

    def _run_text_encoder(self, text, img, n_prompt):
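        """Decode the optional image bytes, set the negative prompt, and run the text encoder."""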
        if img is not None:
            img = image_transporter.load_image(img)
        self.runner.config["negative_prompt"] = n_prompt
        text_encoder_output = self.runner.run_text_encoder(text, img)
        return text_encoder_output


def run_text_encoder(message: Message):
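    """Run one encoding task and record its completion or failure in the service status."""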
    try:
        global runner
        text_encoder_output = runner._run_text_encoder(message.text, message.img, message.n_prompt)
        TextEncoderServiceStatus.complete_task(message)
        return text_encoder_output
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        TextEncoderServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/text_encoders/generate")
def v1_local_text_encoder_generate(message: Message):
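    """Start a task, run the text encoder, and return the output packaged for transport."""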
    try:
        task_id = TextEncoderServiceStatus.start_task(message)
        text_encoder_output = run_text_encoder(message)
        output = tensor_transporter.prepare_tensor(text_encoder_output)
        del text_encoder_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}

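# Illustrative client sketch for the /v1/local/text_encoders/generate endpoint above
# (reading the returned tensor requires the matching TensorTransporter on the client side):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:9002/v1/local/text_encoders/generate",
#       json={"task_id": "demo-0001", "text": "a red panda riding a bicycle"},
#   )
#   print(resp.json()["task_status"])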

@app.get("/v1/local/text_encoders/generate/service_status")
async def get_service_status():
    return TextEncoderServiceStatus.get_status_service()


@app.get("/v1/local/text_encoders/generate/get_all_tasks")
async def get_all_tasks():
    return TextEncoderServiceStatus.get_all_tasks()


@app.post("/v1/local/text_encoders/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return TextEncoderServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
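# Example launch (model and config paths are placeholders):
#   python text_encoder.py --model_cls wan2.1 --task t2v \
#       --model_path /path/to/model --config_json /path/to/config.json --port 9002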

if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)

    parser.add_argument("--port", type=int, default=9002)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = TextEncoderRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)