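"""Standalone FastAPI wrapper around the CLIP image encoder for lightx2v's
split-server mode: it receives an image over HTTP and returns the encoded
features as a serialized tensor."""
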
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import os
import torch
import torchvision.transforms.functional as TF

from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter

# Transporters (de)serialize tensors and images across the service boundary.
tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================

# Set in __main__ once the model has been loaded.
runner = None

app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False

    img: bytes

    def get(self, key, default=None):
        return getattr(self, key, default)
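
# Request-body sketch for the generate endpoint below (field names come from
# the Message model above; exactly how the image bytes are encoded on the wire
# is up to ImageTransporter on the client side, so the value shown here is an
# assumption):
#
#   {"task_id": "demo-001", "task_id_must_unique": false, "img": "<serialized image bytes>"}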


class ImageEncoderServiceStatus(BaseServiceStatus):
    pass


class ImageEncoderRunner:
    def __init__(self, config):
        self.config = config
        self.image_encoder = self.get_image_encoder_model()

    def get_image_encoder_model(self):
        # Only wan2.1-style models are supported: they ship an open-clip
        # XLM-RoBERTa-large ViT-H/14 checkpoint alongside the tokenizer.
        if "wan2.1" in self.config.model_cls:
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device="cuda",
                checkpoint_path=os.path.join(
                    self.config.model_path,
                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                ),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )
        else:
            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
        return image_encoder

    def _run_image_encoder(self, img):
        if "wan2.1" in self.config.model_cls:
            img = image_transporter.load_image(img)
            # to_tensor yields values in [0, 1]; shift/scale to [-1, 1].
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
            # Add a singleton frame dimension (C, 1, H, W) before encoding.
            clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        else:
            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
        return clip_encoder_out


def run_image_encoder(message: Message):
    try:
        global runner
        image_encoder_out = runner._run_image_encoder(message.img)
        assert image_encoder_out is not None
        ImageEncoderServiceStatus.complete_task(message)
        return image_encoder_out
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        ImageEncoderServiceStatus.record_failed_task(message, error=str(e))
        # Re-raise so the endpoint reports the failure instead of serializing None.
        raise


@app.post("/v1/local/image_encoder/generate")
def v1_local_image_encoder_generate(message: Message):
    try:
        task_id = ImageEncoderServiceStatus.start_task(message)
        image_encoder_output = run_image_encoder(message)
        output = tensor_transporter.prepare_tensor(image_encoder_output)
        del image_encoder_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except Exception as e:
        return {"error": str(e)}


@app.get("/v1/local/image_encoder/generate/service_status")
async def get_service_status():
    return ImageEncoderServiceStatus.get_status_service()


@app.get("/v1/local/image_encoder/generate/get_all_tasks")
async def get_all_tasks():
    return ImageEncoderServiceStatus.get_all_tasks()


@app.post("/v1/local/image_encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return ImageEncoderServiceStatus.get_status_task_id(message.task_id)
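
# Status-endpoint sketches against a locally running instance (host and port
# are assumptions; the routes themselves are defined above):
#
#   curl http://localhost:9003/v1/local/image_encoder/generate/service_status
#   curl http://localhost:9003/v1/local/image_encoder/generate/get_all_tasks
#   curl -X POST http://localhost:9003/v1/local/image_encoder/generate/task_status \
#        -H "Content-Type: application/json" -d '{"task_id": "demo-001"}'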


# =========================
# Main Entry
# =========================

if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)

    parser.add_argument("--port", type=int, default=9003)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    assert args.task == "i2v", "the image encoder service is only used for the i2v task"

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        config["mode"] = "split_server"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = ImageEncoderRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
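
# Example launch (paths are placeholders; wan2.1-style checkpoints are the only
# model_cls this service actually supports, see get_image_encoder_model):
#
#   python image_encoder.py \
#       --model_cls wan2.1 \
#       --task i2v \
#       --model_path /path/to/wan2.1_model_dir \
#       --config_json /path/to/config.json \
#       --port 9003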