"backends/vscode:/vscode.git/clone" did not exist on "2e4f4ba1bb0c656e19d01e268573fdbfaf5f7705"
Commit a852f879 authored by Zhuguanyu Wu, committed by GitHub

[feature] support split server (#50)

* add load_transformer methods for split server

* add service utils

* [feature] support split servers
parent f2a586f1
{
"infer_steps": 20,
"target_video_length": 33,
"i2v_resolution": "720p",
"attention_type": "flash_attn3",
"seed": 0,
"sub_servers": {
"prompt_enhancer": ["http://localhost:9001"],
"text_encoders": ["http://localhost:9002"],
"vae_model": ["http://localhost:9004"]
}
}
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"sub_servers": {
"prompt_enhancer": ["http://localhost:9001"],
"text_encoders": ["http://localhost:9002"],
"vae_model": ["http://localhost:9004"]
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"sub_servers": {
"prompt_enhancer": ["http://localhost:9001"],
"text_encoders": ["http://localhost:9002"],
"image_encoder": ["http://localhost:9003"],
"vae_model": ["http://localhost:9004"]
}
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"sub_servers": {
"prompt_enhancer": ["http://localhost:9001"],
"text_encoders": ["http://localhost:9002"],
"image_encoder": ["http://localhost:9003"],
"vae_model": ["http://localhost:9004"]
}
}
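The sub_servers lists above are what the main server polls before dispatching each encode/decode request, so a quick sanity check for a deploy config is to query every sub-server's service_status route. A minimal sketch (assuming the wan_i2v.json config shipped in this PR; the prompt_enhancer route is not part of this diff, so it is omitted, and the config path is illustrative):

import json
import requests

with open("configs/deploy/wan_i2v.json") as f:  # illustrative path
    cfg = json.load(f)

# task_type prefixes used by DefaultRunner.post_task for each sub_servers key
routes = {
    "text_encoders": "text_encoder",
    "image_encoder": "image_encoder",
    "vae_model": "vae_model/encoder",
}

for key, task_type in routes.items():
    for url in cfg["sub_servers"].get(key, []):
        status = requests.get(f"{url}/v1/local/{task_type}/generate/service_status", timeout=5).json()
        print(key, url, status)  # expect {"service_status": "idle"} when free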
import signal
import sys
import psutil
import asyncio
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
......@@ -17,32 +15,7 @@ import torch
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner
# =========================
# Signal & Process Control
# =========================
def kill_all_related_processes():
"""Kill the current process and all its child processes"""
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
logger.info(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
logger.info(f"Failed to kill main process: {e}")
def signal_handler(sig, frame):
logger.info("\nReceived Ctrl+C, shutting down all related processes...")
kill_all_related_processes()
sys.exit(0)
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager
# =========================
......@@ -70,67 +43,8 @@ class Message(BaseModel):
return getattr(self, key, default)
class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message: Message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id_must_unique and message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message: Message):
with cls._lock:
cls._result_store[message.task_id] = {"success": True, "message": message, "start_time": cls._current_task["start_time"], "completion_time": datetime.now()}
cls._current_task = None
@classmethod
def record_failed_task(cls, message: Message, error: Optional[str] = None):
"""Record a failed task with an error message."""
with cls._lock:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"task_status": "processing"}
if task_id in cls._result_store:
return {"task_status": "completed", **cls._result_store[task_id]}
return {"task_status": "not_found"}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
class ApiServerServiceStatus(BaseServiceStatus):
pass
def local_video_generate(message: Message):
......@@ -138,17 +52,22 @@ def local_video_generate(message: Message):
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
runner.run_pipeline()
ServiceStatus.complete_task(message)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(runner.run_pipeline())
finally:
loop.close()
ApiServerServiceStatus.complete_task(message)
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
ApiServerServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
try:
task_id = ServiceStatus.start_task(message)
task_id = ApiServerServiceStatus.start_task(message)
# Use a background thread to perform the long-running task
global thread
thread = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
......@@ -160,17 +79,17 @@ async def v1_local_video_generate(message: Message):
@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
return ServiceStatus.get_status_service()
return ApiServerServiceStatus.get_status_service()
@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
return ServiceStatus.get_all_tasks()
return ApiServerServiceStatus.get_all_tasks()
@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ServiceStatus.get_status_task_id(message.task_id)
return ApiServerServiceStatus.get_status_task_id(message.task_id)
def _async_raise(tid, exctype):
......@@ -193,7 +112,7 @@ async def stop_running_task():
# Clean up the thread reference
thread = None
ServiceStatus.clean_stopped_task()
ApiServerServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
return {"stop_status": "success", "reason": "Task stopped successfully."}
......@@ -208,13 +127,13 @@ async def stop_running_task():
# =========================
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt_enhancer", default=None)
parser.add_argument("--split", action="store_true")
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
......@@ -222,6 +141,7 @@ if __name__ == "__main__":
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server" if args.split else "server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
......
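For reference, a hypothetical client for the api_server endpoints above: submit one generation task, then poll its status until it leaves the "processing" state. The Message fields of the api_server are not fully visible in this diff, so the payload keys below are inferred from DefaultRunner.set_inputs, and the prompt and paths are illustrative.

import time
import requests

BASE = "http://localhost:8000"

payload = {
    "task_id": "demo-0001",
    "task_id_must_unique": True,
    "prompt": "a cat surfing a wave",           # illustrative
    "negative_prompt": "",
    "image_path": "/path/to/input.png",         # only consumed by i2v tasks
    "save_video_path": "/path/to/output.mp4",
}

print(requests.post(f"{BASE}/v1/local/video/generate", json=payload).json())

while True:
    status = requests.post(
        f"{BASE}/v1/local/video/generate/task_status",
        json={"task_id": payload["task_id"]},
    ).json()
    if status["task_status"] != "processing":
        print(status)  # completed (with success/error details) or not_found
        break
    time.sleep(5)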
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import os
import torch
import torchvision.transforms.functional as TF
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()
# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
img: bytes
def get(self, key, default=None):
return getattr(self, key, default)
class ImageEncoderServiceStatus(BaseServiceStatus):
pass
class ImageEncoderRunner:
def __init__(self, config):
self.config = config
self.image_encoder = self.get_image_encoder_model()
def get_image_encoder_model(self):
if "wan2.1" in self.config.model_cls:
image_encoder = CLIPModel(
dtype=torch.float16,
device="cuda",
checkpoint_path=os.path.join(
self.config.model_path,
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
)
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return image_encoder
def _run_image_encoder(self, img):
if "wan2.1" in self.config.model_cls:
img = image_transporter.load_image(img)
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return clip_encoder_out
def run_image_encoder(message: Message):
try:
global runner
image_encoder_out = runner._run_image_encoder(message.img)
assert image_encoder_out is not None
ImageEncoderServiceStatus.complete_task(message)
return image_encoder_out
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
ImageEncoderServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/image_encoder/generate")
def v1_local_image_encoder_generate(message: Message):
try:
task_id = ImageEncoderServiceStatus.start_task(message)
image_encoder_output = run_image_encoder(message)
output = tensor_transporter.prepare_tensor(image_encoder_output)
del image_encoder_output
return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
except RuntimeError as e:
return {"error": str(e)}
@app.get("/v1/local/image_encoder/generate/service_status")
async def get_service_status():
return ImageEncoderServiceStatus.get_status_service()
@app.get("/v1/local/image_encoder/generate/get_all_tasks")
async def get_all_tasks():
return ImageEncoderServiceStatus.get_all_tasks()
@app.post("/v1/local/image_encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ImageEncoderServiceStatus.get_status_task_id(message.task_id)
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--port", type=int, default=9003)
args = parser.parse_args()
logger.info(f"args: {args}")
assert args.task == "i2v"
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = ImageEncoderRunner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
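The sub-servers exchange tensors and images as base64 strings inside JSON bodies (see TensorTransporter and ImageTransporter later in this PR). A sketch of calling the image_encoder service directly, assuming it was launched with the wan2.1 i2v config on its default port 9003; the input file name is illustrative:

import base64
import io
import requests
import torch
from PIL import Image

buf = io.BytesIO()
Image.open("input.png").convert("RGB").save(buf, format="PNG")
img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

resp = requests.post(
    "http://localhost:9003/v1/local/image_encoder/generate",
    json={"task_id": "clip-0001", "img": img_b64},
).json()

# "output" is a base64-encoded torch.save payload produced by TensorTransporter
clip_out = torch.load(io.BytesIO(base64.b64decode(resp["output"])), map_location="cpu")
print(clip_out.shape, clip_out.dtype)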
import argparse
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import os
import torch
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()
# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
text: str
img: Optional[bytes] = None
def get(self, key, default=None):
return getattr(self, key, default)
class TextEncoderServiceStatus(BaseServiceStatus):
pass
class TextEncoderRunner:
def __init__(self, config):
self.config = config
self.text_encoders = self.get_text_encoder_model()
def get_text_encoder_model(self):
if "wan2.1" in self.config.model_cls:
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
device="cuda",
checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
elif self.config.model_cls in ["hunyuan"]:
if self.config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), "cuda")
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), "cuda")
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), "cuda")
text_encoders = [text_encoder_1, text_encoder_2]
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return text_encoders
def _run_text_encoder(self, text, img):
if "wan2.1" in self.config.model_cls:
text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text], self.config)
context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""], self.config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
elif self.config.model_cls in ["hunyuan"]:
text_encoder_output = {}
for i, encoder in enumerate(self.text_encoders):
if self.config.task == "i2v" and i == 0:
img = image_transporter.load_image(img)
text_state, attention_mask = encoder.infer(text, img, self.config)
else:
text_state, attention_mask = encoder.infer(text, self.config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return text_encoder_output
def run_text_encoder(message: Message):
try:
global runner
text_encoder_output = runner._run_text_encoder(message.text, message.img)
TextEncoderServiceStatus.complete_task(message)
return text_encoder_output
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
TextEncoderServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/text_encoder/generate")
def v1_local_text_encoder_generate(message: Message):
try:
task_id = TextEncoderServiceStatus.start_task(message)
text_encoder_output = run_text_encoder(message)
output = tensor_transporter.prepare_tensor(text_encoder_output)
del text_encoder_output
return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
except RuntimeError as e:
return {"error": str(e)}
@app.get("/v1/local/text_encoder/generate/service_status")
async def get_service_status():
return TextEncoderServiceStatus.get_status_service()
@app.get("/v1/local/text_encoder/generate/get_all_tasks")
async def get_all_tasks():
return TextEncoderServiceStatus.get_all_tasks()
@app.post("/v1/local/text_encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return TextEncoderServiceStatus.get_status_task_id(message.task_id)
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--port", type=int, default=9002)
args = parser.parse_args()
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = TextEncoderRunner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
from typing import Optional
import numpy as np
import uvicorn
import json
import os
import torch
import torchvision
import torchvision.transforms.functional as TF
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter
tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()
# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
img: Optional[bytes] = None
latents: Optional[bytes] = None
def get(self, key, default=None):
return getattr(self, key, default)
class VAEServiceStatus(BaseServiceStatus):
pass
class VAEEncoderRunner:
def __init__(self, config):
self.config = config
self.vae_model = self.get_vae_model()
def get_vae_model(self):
if "wan2.1" in self.config.model_cls:
vae_model = WanVAE(
vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
device="cuda",
parallel=self.config.parallel_vae,
)
elif self.config.model_cls in ["hunyuan"]:
vae_model = VideoEncoderKLCausal3DModel(model_path=self.config.model_path, dtype=torch.float16, device="cuda", config=self.config)
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return vae_model
def _run_vae_encoder(self, img):
img = image_transporter.load_image(img)
kwargs = {}
if "wan2.1" in self.config.model_cls:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:]
aspect_ratio = h / w
max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
msk = torch.ones(1, self.config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
vae_encode_out = self.vae_model.encode(
[
torch.concat(
[
torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 1, h, w),
],
dim=1,
).cuda()
],
self.config,
)[0]
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
elif self.config.model_cls in ["hunyuan"]:
if self.config.i2v_resolution == "720p":
bucket_hw_base_size = 960
elif self.config.i2v_resolution == "540p":
bucket_hw_base_size = 720
elif self.config.i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"self.config.i2v_resolution: {self.config.i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = img.size
crop_size_list = HunyuanRunner.generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = HunyuanRunner.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
self.config.target_height, self.config.target_width = closest_size
kwargs["target_height"], kwargs["target_width"] = closest_size
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
vae_encode_out = self.vae_model.encode(semantic_image_pixel_values, self.config).mode()
scaling_factor = 0.476986
vae_encode_out.mul_(scaling_factor)
else:
raise ValueError(f"Unsupported model class: {self.config.model_cls}")
return vae_encode_out, kwargs
def _run_vae_decoder(self, latents):
latents = tensor_transporter.load_tensor(latents)
images = self.vae_model.decode(latents, generator=None, config=self.config)
return images
def run_vae_encoder(message: Message):
try:
global runner
vae_encode_out, kwargs = runner._run_vae_encoder(message.img)
VAEServiceStatus.complete_task(message)
return vae_encode_out, kwargs
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
VAEServiceStatus.record_failed_task(message, error=str(e))
def run_vae_decoder(message: Message):
try:
global runner
images = runner._run_vae_decoder(message.latents)
VAEServiceStatus.complete_task(message)
return images
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
VAEServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/vae_model/encoder/generate")
def v1_local_vae_model_encoder_generate(message: Message):
try:
task_id = VAEServiceStatus.start_task(message)
vae_encode_out, kwargs = run_vae_encoder(message)
output = tensor_transporter.prepare_tensor(vae_encode_out)
del vae_encode_out
return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": kwargs}
except RuntimeError as e:
return {"error": str(e)}
@app.post("/v1/local/vae_model/decoder/generate")
def v1_local_vae_model_decoder_generate(message: Message):
try:
task_id = VAEServiceStatus.start_task(message)
vae_decode_out = run_vae_decoder(message)
output = tensor_transporter.prepare_tensor(vae_decode_out)
del vae_decode_out
return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
except RuntimeError as e:
return {"error": str(e)}
@app.get("/v1/local/vae_model/encoder/generate/service_status")
async def get_service_status():
return VAEServiceStatus.get_status_service()
@app.get("/v1/local/vae_model/decoder/generate/service_status")
async def get_service_status():
return VAEServiceStatus.get_status_service()
@app.get("/v1/local/vae_model/encoder/generate/get_all_tasks")
async def get_all_tasks():
return VAEServiceStatus.get_all_tasks()
@app.get("/v1/local/vae_model/decoder/generate/get_all_tasks")
async def get_all_tasks():
return VAEServiceStatus.get_all_tasks()
@app.post("/v1/local/vae_model/encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return VAEServiceStatus.get_status_task_id(message.task_id)
@app.post("/v1/local/vae_model/decoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return VAEServiceStatus.get_status_task_id(message.task_id)
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--port", type=int, default=9004)
args = parser.parse_args()
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = VAEEncoderRunner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
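The decoder route mirrors what DefaultRunner.run_vae sends in split_server mode: latents serialized with TensorTransporter under the "latents" key, decoded frames returned under "output". A sketch that only illustrates the wire format; the random latents and their shape are assumptions (roughly a wan2.1 480x832, 81-frame latent), not a requirement of the server:

import base64
import io
import requests
import torch

def to_b64(tensor):
    buf = io.BytesIO()
    torch.save(tensor.cpu(), buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")

latents = torch.randn(16, 21, 60, 104, dtype=torch.bfloat16)  # illustrative shape

resp = requests.post(
    "http://localhost:9004/v1/local/vae_model/decoder/generate",
    json={"task_id": "vae-dec-0001", "latents": to_b64(latents)},
).json()

frames = torch.load(io.BytesIO(base64.b64decode(resp["output"])), map_location="cpu")
print(frames.shape)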
......@@ -50,6 +50,7 @@ if __name__ == "__main__":
with ProfilingContext("Total Cost"):
config = set_config(args)
config["mode"] = "infer"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
......
import asyncio
import gc
import aiohttp
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.prompt_enhancer import PromptEnhancer
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger
class DefaultRunner:
def __init__(self, config):
self.config = config
if self.config.prompt_enhancer is not None and self.config.task == "t2v":
self.load_prompt_enhancer()
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
@ProfilingContext("Load prompt enhancer")
def load_prompt_enhancer(self):
gpu_count = torch.cuda.device_count()
if gpu_count == 1:
logger.info("Only one GPU, use prompt enhancer cpu offload")
raise NotImplementedError("prompt enhancer cpu offload is not supported.")
self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
self.config["use_prompt_enhancer"] = True # Set use_prompt_enhancer to True now. (Default is False)
# TODO: implement prompt enhancer
self.has_prompt_enhancer = False
# if self.config.prompt_enhancer is not None and self.config.task == "t2v":
# self.config["use_prompt_enhancer"] = True
# self.has_prompt_enhancer = True
if self.config["mode"] == "split_server":
self.model = self.load_transformer()
self.text_encoders, self.vae_model, self.image_encoder = None, None, None
self.tensor_transporter = TensorTransporter()
self.image_transporter = ImageTransporter()
else:
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
def set_inputs(self, inputs):
self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from clinet side.
if self.has_prompt_enhancer and self.config["mode"] != "infer":
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from clinet side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
def run_input_encoder(self):
async def post_encoders(self, prompt, img=None, i2v=False):
tasks = []
img_byte = self.image_transporter.prepare_image(img) if img is not None else None
if i2v:
if "wan2.1" in self.config["model_cls"]:
tasks.append(
asyncio.create_task(
self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
)
)
tasks.append(
asyncio.create_task(
self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
)
)
tasks.append(
asyncio.create_task(
self.post_task(task_type="text_encoder", urls=self.config["sub_servers"]["text_encoders"], message={"task_id": generate_task_id(), "text": prompt, "img": img_byte}, device="cuda")
)
)
results = await asyncio.gather(*tasks)
# clip_encoder, vae_encoder, text_encoder
if not i2v:
return None, None, results[0]
if "wan2.1" in self.config["model_cls"]:
return results[0], results[1], results[2]
else:
return None, results[0], results[1]
async def run_input_encoder(self):
image_encoder_output = None
if self.config["task"] == "i2v":
with ProfilingContext("Run Img Encoder"):
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
with ProfilingContext("Run Text Encoder"):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
i2v = self.config["task"] == "i2v"
img = Image.open(self.config["image_path"]).convert("RGB") if i2v else None
with ProfilingContext("Run Encoders"):
if self.config["mode"] == "split_server":
clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders(prompt, img, i2v)
if i2v:
if self.config["model_cls"] in ["hunyuan"]:
image_encoder_output = {"img": img, "img_latents": vae_encode_out}
elif "wan2.1" in self.config["model_cls"]:
image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
else:
raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
else:
if i2v:
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
......@@ -60,9 +106,9 @@ class DefaultRunner:
return self.model.scheduler.latents, self.model.scheduler.generator
def run_step(self, step_index=0):
async def run_step(self, step_index=0):
self.init_scheduler()
self.run_input_encoder()
await self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.model.scheduler.step_pre(step_index=step_index)
self.model.infer(self.inputs)
......@@ -74,8 +120,16 @@ class DefaultRunner:
torch.cuda.empty_cache()
@ProfilingContext("Run VAE")
def run_vae(self, latents, generator):
images = self.vae_model.decode(latents, generator=generator, config=self.config)
async def run_vae(self, latents, generator):
if self.config["mode"] == "split_server":
images = await self.post_task(
task_type="vae_model/decoder",
urls=self.config["sub_servers"]["vae_model"],
message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
device="cpu",
)
else:
images = self.vae_model.decode(latents, generator=generator, config=self.config)
return images
@ProfilingContext("Save video")
......@@ -86,15 +140,30 @@ class DefaultRunner:
else:
save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
def run_pipeline(self):
async def post_task(self, task_type, urls, message, device="cuda"):
while True:
for url in urls:
async with aiohttp.ClientSession() as session:
async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
status = await response.json()
if status["service_status"] == "idle":
async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
result = await response.json()
if result["kwargs"] is not None:
for k, v in result["kwargs"].items():
setattr(self.config, k, v)
return self.tensor_transporter.load_tensor(result["output"], device)
await asyncio.sleep(0.1)
async def run_pipeline(self):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.prompt_enhancer(self.config["prompt"])
self.init_scheduler()
self.run_input_encoder()
await self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
latents, generator = self.run()
self.end_run()
images = self.run_vae(latents, generator)
images = await self.run_vae(latents, generator)
self.save_video(images)
del latents, generator, images
gc.collect()
......
......@@ -21,6 +21,13 @@ class HunyuanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self):
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
return HunyuanModel(self.config.model_path, self.config, init_device, self.config)
@ProfilingContext("Load models")
def load_model(self):
if self.config["parallel_attn_type"]:
......@@ -64,7 +71,8 @@ class HunyuanRunner(DefaultRunner):
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
return text_encoder_output
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
@staticmethod
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
......@@ -79,7 +87,8 @@ class HunyuanRunner(DefaultRunner):
return closest_size, closest_ratio
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
@staticmethod
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
......
......@@ -29,6 +29,13 @@ class WanCausVidRunner(WanRunner):
self.infer_blocks = self.model.config.num_blocks
self.num_fragments = self.model.config.num_fragments
def load_transformer(self):
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
return WanCausVidModel(self.config.model_path, self.config, init_device)
@ProfilingContext("Load models")
def load_model(self):
if self.config["parallel_attn_type"]:
......
......@@ -25,6 +25,19 @@ class WanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self):
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
model = WanModel(self.config.model_path, self.config, init_device)
if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
return model
@ProfilingContext("Load models")
def load_model(self):
if self.config["parallel_attn_type"]:
......
import sys
import psutil
import signal
import base64
from PIL import Image
from loguru import logger
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel
import threading
import torch
import io
class ProcessManager:
@staticmethod
def kill_all_related_processes():
"""Kill the current process and all its child processes"""
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
logger.info(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
logger.info(f"Failed to kill main process: {e}")
@staticmethod
def signal_handler(sig, frame):
logger.info("\nReceived Ctrl+C, shutting down all related processes...")
ProcessManager.kill_all_related_processes()
sys.exit(0)
@staticmethod
def register_signal_handler():
"""Register the signal handler for SIGINT"""
signal.signal(signal.SIGINT, ProcessManager.signal_handler)
class TaskStatusMessage(BaseModel):
task_id: str
class BaseServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id_must_unique and message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message):
with cls._lock:
cls._result_store[message.task_id] = {"success": True, "message": message, "start_time": cls._current_task["start_time"], "completion_time": datetime.now()}
cls._current_task = None
@classmethod
def record_failed_task(cls, message, error: Optional[str] = None):
"""Record a failed task with an error message."""
with cls._lock:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"task_status": "processing"}
if task_id in cls._result_store:
return {"task_status": "completed", **cls._result_store[task_id]}
return {"task_status": "not_found"}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
class TensorTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
def prepare_tensor(self, data: torch.Tensor) -> bytes:
self.buffer.seek(0)
self.buffer.truncate()
torch.save(self.to_device(data, "cpu"), self.buffer)
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_tensor(self, tensor_base64: str, device="cuda") -> torch.Tensor:
tensor_bytes = base64.b64decode(tensor_base64)
with io.BytesIO(tensor_bytes) as buffer:
return self.to_device(torch.load(buffer), device)
class ImageTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def prepare_image(self, image: Image.Image) -> bytes:
self.buffer.seek(0)
self.buffer.truncate()
image.save(self.buffer, format="PNG")
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_image(self, image_base64: bytes) -> Image.Image:
image_bytes = base64.b64decode(image_base64)
with io.BytesIO(image_bytes) as buffer:
return Image.open(buffer).convert("RGB")
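A round-trip sketch for the two transporters above (the tensor shape is illustrative). Every sub-server call in this PR relies on this base64-over-JSON encoding in both directions:

import torch
from PIL import Image
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter

tt = TensorTransporter()
payload = {"context": torch.randn(1, 512, 4096, dtype=torch.bfloat16)}  # illustrative
wire = tt.prepare_tensor(payload)              # base64 str, safe to embed in a JSON body
restored = tt.load_tensor(wire, device="cpu")  # nested tensors moved to the target device
assert torch.equal(payload["context"], restored["context"])

it = ImageTransporter()
img = Image.new("RGB", (64, 64), color="white")
img_wire = it.prepare_image(img)               # base64-encoded PNG
img_back = it.load_image(img_wire)             # PIL.Image in RGB mode
print(len(wire), img_back.size)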
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.common.apis.image_encoder \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9003
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.api_server \
--model_cls wan2.1 \
--task i2v \
--split \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 8000
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.common.apis.text_encoder \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9002
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.common.apis.vae \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9004
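A possible single-machine launch order for the four scripts above: bring up the three sub-servers first, give them time to load their checkpoints, then start the main api_server with --split. The script file names and per-process GPU assignments below are placeholders; adapt them to your checkout.

CUDA_VISIBLE_DEVICES=1 bash start_text_encoder_server.sh &
CUDA_VISIBLE_DEVICES=2 bash start_image_encoder_server.sh &
CUDA_VISIBLE_DEVICES=3 bash start_vae_server.sh &
sleep 60   # wait for the encoder/VAE checkpoints to finish loading
CUDA_VISIBLE_DEVICES=0 bash start_server_split.sh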