Commit b7d2d43f authored by Zhuguanyu Wu, committed by GitHub

support split server for dit module (#58)

* split dit server from default runner

* split dit server from default runner

* update loading functions

* simplify loader functions and runner functions

* simplify code && split dit service

* simplify code && split dit service

* support split server for cogvideox

* clear code.
parent 2bb1b0f0
@@ -5,7 +5,9 @@
     "attention_type": "flash_attn3",
     "seed": 0,
     "sub_servers": {
+        "dit": ["http://localhost:9000"],
         "prompt_enhancer": ["http://localhost:9001"],
+        "image_encoder": ["http://localhost:9003"],
         "text_encoders": ["http://localhost:9002"],
         "vae_model": ["http://localhost:9004"]
     }
...
@@ -6,7 +6,9 @@
     "attention_type": "flash_attn3",
     "seed": 42,
     "sub_servers": {
+        "dit": ["http://localhost:9000"],
         "prompt_enhancer": ["http://localhost:9001"],
+        "image_encoder": ["http://localhost:9003"],
         "text_encoders": ["http://localhost:9002"],
         "vae_model": ["http://localhost:9004"]
     }
...
@@ -10,6 +10,7 @@
     "enable_cfg": true,
     "cpu_offload": false,
     "sub_servers": {
+        "dit": ["http://localhost:9000"],
         "prompt_enhancer": ["http://localhost:9001"],
         "text_encoders": ["http://localhost:9002"],
         "image_encoder": ["http://localhost:9003"],
...
@@ -11,6 +11,7 @@
     "enable_cfg": true,
     "cpu_offload": false,
     "sub_servers": {
+        "dit": ["http://localhost:9000"],
         "prompt_enhancer": ["http://localhost:9001"],
         "text_encoders": ["http://localhost:9002"],
         "image_encoder": ["http://localhost:9003"],
...
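Each of the config files above now registers the new "dit" sub-server at port 9000 alongside the existing encoder and VAE services. Below is a minimal, illustrative sketch (not part of the patch) of how a client or the orchestrating runner might wait for that service before submitting work; it only uses the service_status route that the new DiT server defines later in this commit, and the URL simply mirrors the config entries.

    import time
    import requests

    sub_servers = {"dit": ["http://localhost:9000"]}  # mirrors the "sub_servers" entries above

    def wait_until_idle(url, route="dit", poll_interval=0.5):
        # Poll the sub-server's status route until it reports "idle".
        while True:
            status = requests.get(f"{url}/v1/local/{route}/generate/service_status").json()
            if status.get("service_status") == "idle":
                return
            time.sleep(poll_interval)

    wait_until_idle(sub_servers["dit"][0])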
@@ -129,7 +129,7 @@ async def stop_running_task():
 if __name__ == "__main__":
     ProcessManager.register_signal_handler()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
...
import argparse
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import os
import torch

from lightx2v.common.ops import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager, TensorTransporter, ImageTransporter

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================

runner = None
app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    inputs: bytes
    kwargs: bytes

    def get(self, key, default=None):
        return getattr(self, key, default)


class DiTServiceStatus(BaseServiceStatus):
    pass


class DiTRunner:
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
        self.runner = self.runner_cls(config)
        self.runner.model = self.runner.load_transformer(self.runner.get_init_device())

    def _run_dit(self, inputs, kwargs):
        self.runner.config.update(tensor_transporter.load_tensor(kwargs))
        self.runner.inputs = tensor_transporter.load_tensor(inputs)
        self.runner.init_scheduler()
        self.runner.model.scheduler.prepare(self.runner.inputs["image_encoder_output"])
        latents, _ = self.runner.run()
        self.runner.end_run()
        return latents


def run_dit(message: Message):
    try:
        global runner
        dit_output = runner._run_dit(message.inputs, message.kwargs)
        DiTServiceStatus.complete_task(message)
        return dit_output
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        DiTServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/dit/generate")
def v1_local_dit_generate(message: Message):
    try:
        task_id = DiTServiceStatus.start_task(message)
        dit_output = run_dit(message)
        output = tensor_transporter.prepare_tensor(dit_output)
        del dit_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/dit/generate/service_status")
async def get_service_status():
    return DiTServiceStatus.get_status_service()


@app.get("/v1/local/dit/generate/get_all_tasks")
async def get_all_tasks():
    return DiTServiceStatus.get_all_tasks()


@app.post("/v1/local/dit/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return DiTServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=9000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        config["mode"] = "split_server"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = DiTRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
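For a quick smoke test of the service defined above, the read-only routes can be exercised directly. This snippet is illustrative only and not part of the commit; it assumes the server is already running on the default port 9000.

    import requests

    base = "http://localhost:9000"

    # Overall service status (idle / busy).
    print(requests.get(f"{base}/v1/local/dit/generate/service_status").json())

    # All tasks the service has recorded so far.
    print(requests.get(f"{base}/v1/local/dit/generate/get_all_tasks").json())

    # Status of a single task; the body mirrors TaskStatusMessage.
    print(requests.post(f"{base}/v1/local/dit/generate/task_status", json={"task_id": "example-task-id"}).json())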
@@ -8,7 +8,11 @@ import os
 import torch
 import torchvision.transforms.functional as TF

-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
+from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.set_config import set_config

@@ -43,38 +47,20 @@ class ImageEncoderServiceStatus(BaseServiceStatus):
 class ImageEncoderRunner:
     def __init__(self, config):
         self.config = config
-        self.image_encoder = self.get_image_encoder_model()
-
-    def get_image_encoder_model(self):
-        if "wan2.1" in self.config.model_cls:
-            image_encoder = CLIPModel(
-                dtype=torch.float16,
-                device="cuda",
-                checkpoint_path=os.path.join(
-                    self.config.model_path,
-                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
-                ),
-                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
-            )
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
-        return image_encoder
+        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
+        self.runner = self.runner_cls(config)
+        self.runner.image_encoder = self.runner.load_image_encoder(self.runner.get_init_device())

     def _run_image_encoder(self, img):
-        if "wan2.1" in self.config.model_cls:
-            img = image_transporter.load_image(img)
-            img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-            clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
-        return clip_encoder_out
+        img = image_transporter.load_image(img)
+        return self.runner.run_image_encoder(img)


 def run_image_encoder(message: Message):
     try:
         global runner
         image_encoder_out = runner._run_image_encoder(message.img)
+        assert image_encoder_out is not None
         ImageEncoderServiceStatus.complete_task(message)
         return image_encoder_out
     except Exception as e:

@@ -116,7 +102,7 @@ async def get_task_status(message: TaskStatusMessage):
 if __name__ == "__main__":
     ProcessManager.register_signal_handler()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
...
@@ -8,10 +8,11 @@ import json
 import os
 import torch

-from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
-from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
-from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
+from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.set_config import set_config

@@ -48,49 +49,16 @@ class TextEncoderServiceStatus(BaseServiceStatus):
 class TextEncoderRunner:
     def __init__(self, config):
         self.config = config
-        self.text_encoders = self.get_text_encoder_model()
-
-    def get_text_encoder_model(self):
-        if "wan2.1" in self.config.model_cls:
-            text_encoder = T5EncoderModel(
-                text_len=self.config["text_len"],
-                dtype=torch.bfloat16,
-                device="cuda",
-                checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
-                tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
-                shard_fn=None,
-            )
-            text_encoders = [text_encoder]
-        elif self.config.model_cls in ["hunyuan"]:
-            if self.config.task == "t2v":
-                text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), "cuda")
-            else:
-                text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), "cuda")
-            text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), "cuda")
-            text_encoders = [text_encoder_1, text_encoder_2]
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
-        return text_encoders
+        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
+        self.runner = self.runner_cls(config)
+        self.runner.text_encoders = self.runner.load_text_encoder(self.runner.get_init_device())

     def _run_text_encoder(self, text, img, n_prompt):
-        if "wan2.1" in self.config.model_cls:
-            text_encoder_output = {}
-            context = self.text_encoders[0].infer([text])
-            context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
-            text_encoder_output["context"] = context
-            text_encoder_output["context_null"] = context_null
-        elif self.config.model_cls in ["hunyuan"]:
-            text_encoder_output = {}
-            for i, encoder in enumerate(self.text_encoders):
-                if self.config.task == "i2v" and i == 0:
-                    img = image_transporter.load_image(img)
-                    text_state, attention_mask = encoder.infer(text, img, self.config)
-                else:
-                    text_state, attention_mask = encoder.infer(text, self.config)
-                text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
-                text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
+        if img is not None:
+            img = image_transporter.load_image(img)
+        self.runner.config["negative_prompt"] = n_prompt
+        text_encoder_output = self.runner.run_text_encoder(text, img)
         return text_encoder_output

@@ -139,7 +107,7 @@ async def get_task_status(message: TaskStatusMessage):
 if __name__ == "__main__":
     ProcessManager.register_signal_handler()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
...
@@ -10,10 +10,13 @@ import os
 import torch
 import torchvision
 import torchvision.transforms.functional as TF

-from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
-from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.common.ops import *
+from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
+from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.set_config import set_config

@@ -46,98 +49,22 @@ class VAEServiceStatus(BaseServiceStatus):
     pass


-class VAEEncoderRunner:
+class VAERunner:
     def __init__(self, config):
         self.config = config
-        self.vae_model = self.get_vae_model()
-
-    def get_vae_model(self):
-        if "wan2.1" in self.config.model_cls:
-            vae_model = WanVAE(
-                vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
-                device="cuda",
-                parallel=self.config.parallel_vae,
-            )
-        elif self.config.model_cls in ["hunyuan"]:
-            vae_model = VideoEncoderKLCausal3DModel(model_path=self.config.model_path, dtype=torch.float16, device="cuda", config=self.config)
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
-        return vae_model
+        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
+        self.runner = self.runner_cls(config)
+        self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae(self.runner.get_init_device())

     def _run_vae_encoder(self, img):
         img = image_transporter.load_image(img)
-        kwargs = {}
-        if "wan2.1" in self.config.model_cls:
-            img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-            h, w = img.shape[1:]
-            aspect_ratio = h / w
-            max_area = self.config.target_height * self.config.target_width
-            lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
-            lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
-            h = lat_h * self.config.vae_stride[1]
-            w = lat_w * self.config.vae_stride[2]
-            self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
-            self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
-            msk = torch.ones(1, self.config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
-            msk[:, 1:] = 0
-            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-            msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
-            msk = msk.transpose(1, 2)[0]
-            vae_encode_out = self.vae_model.encode(
-                [
-                    torch.concat(
-                        [
-                            torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                            torch.zeros(3, self.config.target_video_length - 1, h, w),
-                        ],
-                        dim=1,
-                    ).cuda()
-                ],
-                self.config,
-            )[0]
-            vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
-        elif self.config.model_cls in ["hunyuan"]:
-            if self.config.i2v_resolution == "720p":
-                bucket_hw_base_size = 960
-            elif self.config.i2v_resolution == "540p":
-                bucket_hw_base_size = 720
-            elif self.config.i2v_resolution == "360p":
-                bucket_hw_base_size = 480
-            else:
-                raise ValueError(f"self.config.i2v_resolution: {self.config.i2v_resolution} must be in [360p, 540p, 720p]")
-            origin_size = img.size
-            crop_size_list = HunyuanRunner.generate_crop_size_list(bucket_hw_base_size, 32)
-            aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
-            closest_size, closest_ratio = HunyuanRunner.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
-            self.config.target_height, self.config.target_width = closest_size
-            kwargs["target_height"], kwargs["target_width"] = closest_size
-            resize_param = min(closest_size)
-            center_crop_param = closest_size
-            ref_image_transform = torchvision.transforms.Compose(
-                [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
-            )
-            semantic_image_pixel_values = [ref_image_transform(img)]
-            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
-            vae_encode_out = self.vae_model.encode(semantic_image_pixel_values, self.config).mode()
-            scaling_factor = 0.476986
-            vae_encode_out.mul_(scaling_factor)
-        else:
-            raise ValueError(f"Unsupported model class: {self.config.model_cls}")
+        vae_encode_out, kwargs = self.runner.run_vae_encoder(img)
         return vae_encode_out, kwargs

     def _run_vae_decoder(self, latents):
         latents = tensor_transporter.load_tensor(latents)
-        images = self.vae_model.decode(latents, generator=None, config=self.config)
+        images = self.runner.vae_decoder.decode(latents, generator=None, config=self.config)
         return images

@@ -229,7 +156,7 @@ async def get_task_status(message: TaskStatusMessage):
 if __name__ == "__main__":
     ProcessManager.register_signal_handler()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)

@@ -242,6 +169,6 @@ if __name__ == "__main__":
         config = set_config(args)
         config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
-        runner = VAEEncoderRunner(config)
+        runner = VAERunner(config)

     uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
...
@@ -32,6 +32,7 @@ def init_runner(config):
         runner = GraphRunner(default_runner)
     else:
         runner = RUNNER_REGISTER[config.model_cls](config)
+    runner.init_modules()
     return runner
...
@@ -16,45 +16,67 @@ class CogvideoxRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)

-    @ProfilingContext("Load models")
-    def load_model(self):
+    def load_transformer(self, init_device):
+        model = CogvideoxModel(self.config)
+        return model
+
+    def load_image_encoder(self, init_device):
+        return None
+
+    def load_text_encoder(self, init_device):
         text_encoder = T5EncoderModel_v1_1_xxl(self.config)
         text_encoders = [text_encoder]
-        model = CogvideoxModel(self.config)
+        return text_encoders
+
+    def load_vae(self, init_device):
         vae_model = CogvideoxVAE(self.config)
-        image_encoder = None
-        return model, text_encoders, vae_model, image_encoder
+        return vae_model, vae_model

     def init_scheduler(self):
         scheduler = CogvideoxXDPMScheduler(self.config)
         self.model.set_scheduler(scheduler)

-    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
+    def run_text_encoder(self, text, img):
         text_encoder_output = {}
-        n_prompt = config.get("negative_prompt", "")
-        context = text_encoders[0].infer([text], config)
-        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
+        n_prompt = self.config.get("negative_prompt", "")
+        context = self.text_encoders[0].infer([text], self.config)
+        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""], self.config)
         text_encoder_output["context"] = context
         text_encoder_output["context_null"] = context_null
         return text_encoder_output

+    def run_vae_encoder(self, img):
+        # TODO: implement vae encoder for Cogvideox
+        raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
+
+    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
+        # TODO: Implement image encoder for Cogvideox-I2V
+        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
+
     def set_target_shape(self):
-        num_frames = self.config.target_video_length
-        latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
-        additional_frames = 0
-        patch_size_t = self.config.patch_size_t
-        if patch_size_t is not None and latent_frames % patch_size_t != 0:
-            additional_frames = patch_size_t - latent_frames % patch_size_t
-            num_frames += additional_frames * self.config.vae_scale_factor_temporal
-        self.config.target_shape = (
-            self.config.batch_size,
-            (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
-            self.config.latent_channels,
-            self.config.height // self.config.vae_scale_factor_spatial,
-            self.config.width // self.config.vae_scale_factor_spatial,
-        )
+        ret = {}
+        if self.config.task == "i2v":
+            # TODO: implement set_target_shape for Cogvideox-I2V
+            raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
+        else:
+            num_frames = self.config.target_video_length
+            latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
+            additional_frames = 0
+            patch_size_t = self.config.patch_size_t
+            if patch_size_t is not None and latent_frames % patch_size_t != 0:
+                additional_frames = patch_size_t - latent_frames % patch_size_t
+                num_frames += additional_frames * self.config.vae_scale_factor_temporal
+            self.config.target_shape = (
+                self.config.batch_size,
+                (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
+                self.config.latent_channels,
+                self.config.height // self.config.vae_scale_factor_spatial,
+                self.config.width // self.config.vae_scale_factor_spatial,
+            )
+        ret["target_shape"] = self.config.target_shape
+        return ret

-    def save_video(self, images):
+    def save_video_func(self, images):
         with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
             for pil_image in images:
                 frame_np = np.array(pil_image, dtype=np.uint8)
...
@@ -25,19 +25,51 @@ class DefaultRunner:
             self.has_prompt_enhancer = False
             logger.warning("No prompt enhancer server available, disable prompt enhancer.")

+    def init_modules(self):
         if self.config["mode"] == "split_server":
-            self.model = self.load_transformer()
-            self.text_encoders, self.vae_model, self.image_encoder = None, None, None
             self.tensor_transporter = TensorTransporter()
             self.image_transporter = ImageTransporter()
+            if not self.check_sub_servers("dit"):
+                raise ValueError("No dit server available")
             if not self.check_sub_servers("text_encoders"):
                 raise ValueError("No text encoder server available")
-            if "wan2.1" in self.config["model_cls"] and not self.check_sub_servers("image_encoder"):
-                raise ValueError("No image encoder server available")
+            if self.config["task"] == "i2v":
+                if not self.check_sub_servers("image_encoder"):
+                    raise ValueError("No image encoder server available")
             if not self.check_sub_servers("vae_model"):
-                raise ValueError("No vae model server available")
+                raise ValueError("No vae server available")
+            self.run_dit = self.run_dit_server
+            self.run_vae_decoder = self.run_vae_decoder_server
+            if self.config["task"] == "i2v":
+                self.run_input_encoder = self.run_input_encoder_server_i2v
+            else:
+                self.run_input_encoder = self.run_input_encoder_server_t2v
         else:
-            self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
+            self.load_model()
+            self.run_dit = self.run_dit_local
+            self.run_vae_decoder = self.run_vae_decoder_local
+            if self.config["task"] == "i2v":
+                self.run_input_encoder = self.run_input_encoder_local_i2v
+            else:
+                self.run_input_encoder = self.run_input_encoder_local_t2v
+
+    def get_init_device(self):
+        if self.config["parallel_attn_type"]:
+            cur_rank = dist.get_rank()
+            torch.cuda.set_device(cur_rank)
+        if self.config.cpu_offload:
+            init_device = torch.device("cpu")
+        else:
+            init_device = torch.device("cuda")
+        return init_device
+
+    @ProfilingContext("Load models")
+    def load_model(self):
+        init_device = self.get_init_device()
+        self.text_encoders = self.load_text_encoder(init_device)
+        self.model = self.load_transformer(init_device)
+        self.image_encoder = self.load_image_encoder(init_device)
+        self.vae_encoder, self.vae_decoder = self.load_vae(init_device)

     def check_sub_servers(self, task_type):
         urls = self.config.get("sub_servers", {}).get(task_type, [])

@@ -66,76 +98,6 @@ class DefaultRunner:
         self.config["image_path"] = inputs.get("image_path", "")
         self.config["save_video_path"] = inputs.get("save_video_path", "")

-    def post_prompt_enhancer(self):
-        while True:
-            for url in self.config["sub_servers"]["prompt_enhancer"]:
-                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
-                if response["service_status"] == "idle":
-                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
-                    self.config["prompt_enhanced"] = response.json()["output"]
-                    logger.info(f"Enhanced prompt: {self.config['prompt_enhanced']}")
-                    return
-
-    async def post_encoders(self, prompt, img=None, n_prompt=None, i2v=False):
-        tasks = []
-        img_byte = self.image_transporter.prepare_image(img) if img is not None else None
-        if i2v:
-            if "wan2.1" in self.config["model_cls"]:
-                tasks.append(
-                    asyncio.create_task(
-                        self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
-                    )
-                )
-            tasks.append(
-                asyncio.create_task(
-                    self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
-                )
-            )
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # clip_encoder, vae_encoder, text_encoders
-        if not i2v:
-            return None, None, results[0]
-        if "wan2.1" in self.config["model_cls"]:
-            return results[0], results[1], results[2]
-        else:
-            return None, results[0], results[1]
-
-    async def run_input_encoder(self):
-        image_encoder_output = None
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        i2v = self.config["task"] == "i2v"
-        img = Image.open(self.config["image_path"]).convert("RGB") if i2v else None
-        with ProfilingContext("Run Encoders"):
-            if self.config["mode"] == "split_server":
-                clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders(prompt, img, n_prompt, i2v)
-                if i2v:
-                    if self.config["model_cls"] in ["hunyuan"]:
-                        image_encoder_output = {"img": img, "img_latents": vae_encode_out}
-                    elif "wan2.1" in self.config["model_cls"]:
-                        image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
-                    else:
-                        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
-            else:
-                if i2v:
-                    image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
-                text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
-        self.set_target_shape()
-        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
-        gc.collect()
-        torch.cuda.empty_cache()
-
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
             logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

@@ -164,26 +126,38 @@ class DefaultRunner:
         del self.inputs, self.model.scheduler
         torch.cuda.empty_cache()

-    @ProfilingContext("Run VAE")
-    async def run_vae(self, latents, generator):
-        if self.config["mode"] == "split_server":
-            images = await self.post_task(
-                task_type="vae_model/decoder",
-                urls=self.config["sub_servers"]["vae_model"],
-                message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
-                device="cpu",
-            )
-        else:
-            images = self.vae_model.decode(latents, generator=generator, config=self.config)
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_i2v(self):
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        img = Image.open(self.config["image_path"]).convert("RGB")
+        clip_encoder_out = self.run_image_encoder(img)
+        vae_encode_out, kwargs = self.run_vae_encoder(img)
+        text_encoder_output = self.run_text_encoder(prompt, img)
+        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
+
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_t2v(self):
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt, None)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+
+    @ProfilingContext("Run DiT")
+    async def run_dit_local(self, kwargs):
+        self.init_scheduler()
+        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+        latents, generator = self.run()
+        self.end_run()
+        return latents, generator
+
+    @ProfilingContext("Run VAE Decoder")
+    async def run_vae_decoder_local(self, latents, generator):
+        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
         return images

     @ProfilingContext("Save video")
     def save_video(self, images):
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
-            if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
-                cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
-            else:
-                save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
+            self.save_video_func(images)

     async def post_task(self, task_type, urls, message, device="cuda"):
         while True:

@@ -200,15 +174,95 @@ class DefaultRunner:
                     return self.tensor_transporter.load_tensor(result["output"], device)
             await asyncio.sleep(0.1)

+    def post_prompt_enhancer(self):
+        while True:
+            for url in self.config["sub_servers"]["prompt_enhancer"]:
+                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
+                if response["service_status"] == "idle":
+                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
+                    enhanced_prompt = response.json()["output"]
+                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
+                    return enhanced_prompt
+
+    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
+        tasks = []
+        img_byte = self.image_transporter.prepare_image(img)
+        tasks.append(
+            asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
+        )
+        tasks.append(
+            asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
+        )
+        tasks.append(
+            asyncio.create_task(
+                self.post_task(
+                    task_type="text_encoders",
+                    urls=self.config["sub_servers"]["text_encoders"],
+                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
+                    device="cuda",
+                )
+            )
+        )
+        results = await asyncio.gather(*tasks)
+        # clip_encoder, vae_encoder, text_encoders
+        return results[0], results[1], results[2]
+
+    async def post_encoders_t2v(self, prompt, n_prompt=None):
+        tasks = []
+        tasks.append(
+            asyncio.create_task(
+                self.post_task(
+                    task_type="text_encoders",
+                    urls=self.config["sub_servers"]["text_encoders"],
+                    message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
+                    device="cuda",
+                )
+            )
+        )
+        results = await asyncio.gather(*tasks)
+        # text_encoders
+        return results[0]
+
+    async def run_input_encoder_server_i2v(self):
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        n_prompt = self.config.get("negative_prompt", "")
+        img = Image.open(self.config["image_path"]).convert("RGB")
+        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
+        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
+
+    async def run_input_encoder_server_t2v(self):
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        n_prompt = self.config.get("negative_prompt", "")
+        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+
+    async def run_dit_server(self, kwargs):
+        if self.inputs.get("image_encoder_output", None) is not None:
+            self.inputs["image_encoder_output"].pop("img", None)
+        dit_output = await self.post_task(
+            task_type="dit",
+            urls=self.config["sub_servers"]["dit"],
+            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
+            device="cuda",
+        )
+        return dit_output, None
+
+    async def run_vae_decoder_server(self, latents, generator):
+        images = await self.post_task(
+            task_type="vae_model/decoder",
+            urls=self.config["sub_servers"]["vae_model"],
+            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
+            device="cpu",
+        )
+        return images
+
     async def run_pipeline(self):
         if self.config["use_prompt_enhancer"]:
             self.config["prompt_enhanced"] = self.post_prompt_enhancer()
-        self.init_scheduler()
-        await self.run_input_encoder()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        latents, generator = self.run()
-        self.end_run()
-        images = await self.run_vae(latents, generator)
+        self.inputs = await self.run_input_encoder()
+        kwargs = self.set_target_shape()
+        latents, generator = await self.run_dit(kwargs)
+        images = await self.run_vae_decoder(latents, generator)
         self.save_video(images)
         del latents, generator, images
         gc.collect()
...
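With this refactor, run_pipeline reduces to a fixed sequence (input encoders, set_target_shape, DiT, VAE decode, save), and init_modules decides whether each stage runs locally or is posted to a sub-server. The following is a hedged sketch of driving the refactored runner, assuming `config` is the object produced by set_config with "mode", "model_cls" and "task" filled in; it is illustrative only and not part of the commit.

    import asyncio
    from lightx2v.utils.registry_factory import RUNNER_REGISTER

    def build_runner(config):
        runner = RUNNER_REGISTER[config.model_cls](config)
        runner.init_modules()  # binds run_dit / run_input_encoder / run_vae_decoder to local or server variants
        return runner

    # Illustrative usage (input keys follow set_inputs as shown in this diff):
    # runner = build_runner(config)
    # runner.set_inputs({"image_path": "ref.png", "save_video_path": "out.mp4"})
    # asyncio.run(runner.run_pipeline())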
@@ -12,7 +12,7 @@ from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
 from lightx2v.models.networks.hunyuan.model import HunyuanModel
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
-import torch.distributed as dist
+from lightx2v.utils.utils import save_videos_grid
 from lightx2v.utils.profiler import ProfilingContext

@@ -21,33 +21,24 @@ class HunyuanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)

-    def load_transformer(self):
-        if self.config.cpu_offload:
-            init_device = torch.device("cpu")
-        else:
-            init_device = torch.device("cuda")
+    def load_transformer(self, init_device):
         return HunyuanModel(self.config.model_path, self.config, init_device, self.config)

-    @ProfilingContext("Load models")
-    def load_model(self):
-        if self.config["parallel_attn_type"]:
-            cur_rank = dist.get_rank()
-            torch.cuda.set_device(cur_rank)
-        image_encoder = None
-        if self.config.cpu_offload:
-            init_device = torch.device("cpu")
-        else:
-            init_device = torch.device("cuda")
+    def load_image_encoder(self, init_device):
+        return None
+
+    def load_text_encoder(self, init_device):
         if self.config.task == "t2v":
             text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), init_device)
         else:
             text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), init_device)
         text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), init_device)
         text_encoders = [text_encoder_1, text_encoder_2]
-        model = HunyuanModel(self.config.model_path, self.config, init_device, self.config)
+        return text_encoders
+
+    def load_vae(self, init_device):
         vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=init_device, config=self.config)
-        return model, text_encoders, vae_model, image_encoder
+        return vae_model, vae_model

     def init_scheduler(self):
         if self.config.feature_caching == "NoCaching":

@@ -60,13 +51,13 @@ class HunyuanRunner(DefaultRunner):
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)

-    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
+    def run_text_encoder(self, text, img):
         text_encoder_output = {}
-        for i, encoder in enumerate(text_encoders):
-            if config.task == "i2v" and i == 0:
-                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
+        for i, encoder in enumerate(self.text_encoders):
+            if self.config.task == "i2v" and i == 0:
+                text_state, attention_mask = encoder.infer(text, img, self.config)
             else:
-                text_state, attention_mask = encoder.infer(text, config)
+                text_state, attention_mask = encoder.infer(text, self.config)
             text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
         return text_encoder_output

@@ -102,17 +93,19 @@ class HunyuanRunner(DefaultRunner):
                 wp -= 1
         return crop_size_list

-    def run_image_encoder(self, config, image_encoder, vae_model):
-        img = Image.open(config.image_path).convert("RGB")
-        if config.i2v_resolution == "720p":
+    def run_image_encoder(self, img):
+        return None
+
+    def run_vae_encoder(self, img):
+        kwargs = {}
+        if self.config.i2v_resolution == "720p":
             bucket_hw_base_size = 960
-        elif config.i2v_resolution == "540p":
+        elif self.config.i2v_resolution == "540p":
             bucket_hw_base_size = 720
-        elif config.i2v_resolution == "360p":
+        elif self.config.i2v_resolution == "360p":
             bucket_hw_base_size = 480
         else:
-            raise ValueError(f"config.i2v_resolution: {config.i2v_resolution} must be in [360p, 540p, 720p]")
+            raise ValueError(f"self.config.i2v_resolution: {self.config.i2v_resolution} must be in [360p, 540p, 720p]")

         origin_size = img.size

@@ -120,7 +113,8 @@ class HunyuanRunner(DefaultRunner):
         aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
         closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

-        config.target_height, config.target_width = closest_size
+        self.config.target_height, self.config.target_width = closest_size
+        kwargs["target_height"], kwargs["target_width"] = closest_size

         resize_param = min(closest_size)
         center_crop_param = closest_size

@@ -132,12 +126,16 @@ class HunyuanRunner(DefaultRunner):
         semantic_image_pixel_values = [ref_image_transform(img)]
         semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

-        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
+        img_latents = self.vae_encoder.encode(semantic_image_pixel_values, self.config).mode()

         scaling_factor = 0.476986
         img_latents.mul_(scaling_factor)

-        return {"img": img, "img_latents": img_latents}
+        return img_latents, kwargs
+
+    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
+        image_encoder_output = {"img": img, "img_latents": vae_encode_out}
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

     def set_target_shape(self):
         vae_scale_factor = 2 ** (4 - 1)

@@ -148,3 +146,7 @@ class HunyuanRunner(DefaultRunner):
             int(self.config.target_height) // vae_scale_factor,
             int(self.config.target_width) // vae_scale_factor,
         )
+        return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
+
+    def save_video_func(self, images):
+        save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
@@ -29,52 +29,9 @@ class WanCausVidRunner(WanRunner):
         self.infer_blocks = self.model.config.num_blocks
         self.num_fragments = self.model.config.num_fragments

-    def load_transformer(self):
-        if self.config.cpu_offload:
-            init_device = torch.device("cpu")
-        else:
-            init_device = torch.device("cuda")
+    def load_transformer(self, init_device):
         return WanCausVidModel(self.config.model_path, self.config, init_device)

-    @ProfilingContext("Load models")
-    def load_model(self):
-        if self.config["parallel_attn_type"]:
-            cur_rank = dist.get_rank()
-            torch.cuda.set_device(cur_rank)
-        image_encoder = None
-        if self.config.cpu_offload:
-            init_device = torch.device("cpu")
-        else:
-            init_device = torch.device("cuda")
-        text_encoder = T5EncoderModel(
-            text_len=self.config["text_len"],
-            dtype=torch.bfloat16,
-            device=init_device,
-            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
-            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
-            shard_fn=None,
-        )
-        text_encoders = [text_encoder]
-        model = WanCausVidModel(self.config.model_path, self.config, init_device)
-        if self.config.lora_path:
-            lora_wrapper = WanLoraWrapper(model)
-            lora_name = lora_wrapper.load_lora(self.config.lora_path)
-            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            logger.info(f"Loaded LoRA: {lora_name}")
-        vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
-        if self.config.task == "i2v":
-            image_encoder = CLIPModel(
-                dtype=torch.float16,
-                device=init_device,
-                checkpoint_path=os.path.join(self.config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
-                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
-            )
-        return model, text_encoders, vae_model, image_encoder
-
     def set_inputs(self, inputs):
         super().set_inputs(inputs)
         self.config["num_fragments"] = inputs.get("num_fragments", 1)
...
...@@ -16,7 +16,7 @@ from lightx2v.models.networks.wan.model import WanModel ...@@ -16,7 +16,7 @@ from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
import torch.distributed as dist from lightx2v.utils.utils import cache_video
from loguru import logger from loguru import logger
...@@ -25,11 +25,7 @@ class WanRunner(DefaultRunner): ...@@ -25,11 +25,7 @@ class WanRunner(DefaultRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def load_transformer(self): def load_transformer(self, init_device):
if self.config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
model = WanModel(self.config.model_path, self.config, init_device) model = WanModel(self.config.model_path, self.config, init_device)
if self.config.lora_path: if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model) lora_wrapper = WanLoraWrapper(model)
...@@ -38,17 +34,21 @@ class WanRunner(DefaultRunner): ...@@ -38,17 +34,21 @@ class WanRunner(DefaultRunner):
logger.info(f"Loaded LoRA: {lora_name}") logger.info(f"Loaded LoRA: {lora_name}")
return model return model
@ProfilingContext("Load models") def load_image_encoder(self, init_device):
def load_model(self):
if self.config["parallel_attn_type"]:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
image_encoder = None image_encoder = None
if self.config.cpu_offload: if self.config.task == "i2v":
init_device = torch.device("cpu") image_encoder = CLIPModel(
else: dtype=torch.float16,
init_device = torch.device("cuda") device=init_device,
checkpoint_path=os.path.join(
self.config.model_path,
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
)
return image_encoder
def load_text_encoder(self, init_device):
text_encoder = T5EncoderModel( text_encoder = T5EncoderModel(
text_len=self.config["text_len"], text_len=self.config["text_len"],
dtype=torch.bfloat16, dtype=torch.bfloat16,
...@@ -60,47 +60,28 @@ class WanRunner(DefaultRunner): ...@@ -60,47 +60,28 @@ class WanRunner(DefaultRunner):
offload_granularity=self.config.get("text_encoder_offload_granularity", "model"), offload_granularity=self.config.get("text_encoder_offload_granularity", "model"),
) )
text_encoders = [text_encoder] text_encoders = [text_encoder]
model = WanModel(self.config.model_path, self.config, init_device) return text_encoders
if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
if self.config.get("tiny_vae", False): def load_vae(self, init_device):
vae_model = WanVAE_tiny( vae_config = {
"vae_pth": os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
"device": init_device,
"parallel": self.config.parallel_vae,
"use_tiling": self.config.get("use_tiling_vae", False),
}
use_tiny_decoder = self.config.get("tiny_vae", False)
is_i2v = self.config.task == "i2v"
if use_tiny_decoder:
vae_decoder = WanVAE_tiny(
vae_pth=self.config.tiny_vae_path, vae_pth=self.config.tiny_vae_path,
device=init_device, device=init_device,
) ).to("cuda")
vae_model = vae_model.to("cuda") vae_encoder = WanVAE(**vae_config) if is_i2v else None
else: else:
vae_model = WanVAE( vae_decoder = WanVAE(**vae_config)
vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), vae_encoder = vae_decoder if is_i2v else None
device=init_device,
parallel=self.config.parallel_vae,
use_tiling=self.config.get("use_tiling_vae", False),
)
if self.config.task == "i2v":
image_encoder = CLIPModel(
dtype=torch.float16,
device=init_device,
checkpoint_path=os.path.join(
self.config.model_path,
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
),
tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
)
if self.config.get("tiny_vae", False):
org_vae = WanVAE(
vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"),
device=init_device,
parallel=self.config.parallel_vae,
use_tiling=self.config.get("use_tiling_vae", False),
)
image_encoder = [image_encoder, org_vae]
return model, text_encoders, vae_model, image_encoder return vae_encoder, vae_decoder
def init_scheduler(self): def init_scheduler(self):
if self.config.feature_caching == "NoCaching": if self.config.feature_caching == "NoCaching":
@@ -111,55 +92,60 @@ class WanRunner(DefaultRunner):
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)

-    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
+    def run_text_encoder(self, text, img):
         text_encoder_output = {}
-        n_prompt = config.get("negative_prompt", "")
-        context = text_encoders[0].infer([text])
-        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""])
+        n_prompt = self.config.get("negative_prompt", "")
+        context = self.text_encoders[0].infer([text])
+        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
         text_encoder_output["context"] = context
         text_encoder_output["context_null"] = context_null
         return text_encoder_output

-    def run_image_encoder(self, config, image_encoder, vae_model):
-        if self.config.get("tiny_vae", False):
-            clip_image_encoder, vae_image_encoder = image_encoder[0], image_encoder[1]
-        else:
-            clip_image_encoder, vae_image_encoder = image_encoder, vae_model
-        img = Image.open(config.image_path).convert("RGB")
+    def run_image_encoder(self, img):
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
+        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        return clip_encoder_out
+
+    def run_vae_encoder(self, img):
+        kwargs = {}
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = clip_image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
         h, w = img.shape[1:]
         aspect_ratio = h / w
-        max_area = config.target_height * config.target_width
-        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
-        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
-        h = lat_h * config.vae_stride[1]
-        w = lat_w * config.vae_stride[2]
-        config.lat_h = lat_h
-        config.lat_w = lat_w
-        msk = torch.ones(1, config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
+        max_area = self.config.target_height * self.config.target_width
+        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
+        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
+        h = lat_h * self.config.vae_stride[1]
+        w = lat_w * self.config.vae_stride[2]
+        self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
+        self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
+        msk = torch.ones(1, self.config.target_video_length, lat_h, lat_w, device=torch.device("cuda"))
         msk[:, 1:] = 0
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        vae_encode_out = vae_image_encoder.encode(
+        vae_encode_out = self.vae_encoder.encode(
             [
                 torch.concat(
                     [
                         torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                        torch.zeros(3, config.target_video_length - 1, h, w),
+                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                     ],
                     dim=1,
                 ).cuda()
             ],
-            config,
+            self.config,
         )[0]
         vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
-        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
+        return vae_encode_out, kwargs
+
+    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
+        image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

     def set_target_shape(self):
+        ret = {}
         num_channels_latents = self.config.get("num_channels_latents", 16)
         if self.config.task == "i2v":
             self.config.target_shape = (
@@ -168,6 +154,8 @@ class WanRunner(DefaultRunner):
                 self.config.lat_h,
                 self.config.lat_w,
             )
+            ret["lat_h"] = self.config.lat_h
+            ret["lat_w"] = self.config.lat_w
         elif self.config.task == "t2v":
             self.config.target_shape = (
                 num_channels_latents,
@@ -175,3 +163,8 @@ class WanRunner(DefaultRunner):
                 int(self.config.target_height) // self.config.vae_stride[1],
                 int(self.config.target_width) // self.config.vae_stride[2],
             )
+        ret["target_shape"] = self.config.target_shape
+        return ret
+
+    def save_video_func(self, images):
+        cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
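The i2v path above snaps the input image to the VAE/patch grid before encoding, and set_target_shape reuses the resulting lat_h/lat_w. A self-contained sketch of that arithmetic follows; the concrete values (480x832 target, vae_stride (4, 8, 8), patch_size (1, 2, 2), 81 frames, 16 latent channels, a 720x1280 input) are assumptions for illustration, not taken from this diff.

import numpy as np

# Assumed config values, for illustration only (not read from this diff).
target_height, target_width = 480, 832
vae_stride = (4, 8, 8)        # temporal, height, width strides of the VAE
patch_size = (1, 2, 2)        # DiT patchify sizes
target_video_length = 81
num_channels_latents = 16

h, w = 720, 1280              # example input image size
aspect_ratio = h / w
max_area = target_height * target_width

# Same rounding as run_vae_encoder: latent dims are snapped to the patch grid.
lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])

# Pixel size fed to the VAE encoder and the latent target shape used by set_target_shape.
h, w = lat_h * vae_stride[1], lat_w * vae_stride[2]
target_shape = (num_channels_latents, (target_video_length - 1) // vae_stride[0] + 1, lat_h, lat_w)
print(lat_h, lat_w, h, w, target_shape)   # -> 58 104 464 832 (16, 21, 58, 104)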
@@ -117,7 +117,7 @@ class TensorTransporter:
         else:
             return data

-    def prepare_tensor(self, data: torch.Tensor) -> bytes:
+    def prepare_tensor(self, data) -> bytes:
         self.buffer.seek(0)
         self.buffer.truncate()
         torch.save(self.to_device(data, "cpu"), self.buffer)
......
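The prepare_tensor change above widens the accepted payload from a bare torch.Tensor to arbitrary data, which matters now that the split servers exchange dicts and lists of tensors between sub-services. Below is a minimal sketch of the underlying round trip with plain torch.save/torch.load over an in-memory buffer; it does not reproduce the TensorTransporter class itself.

import io

import torch

# Serialize a nested payload (tensors plus plain Python values) to bytes.
payload = {"context": torch.randn(2, 8), "lat_h": 58, "lat_w": 104}
buffer = io.BytesIO()
torch.save(payload, buffer)                      # accepts tensors, dicts, lists, scalars
raw = buffer.getvalue()

# Receiving side: restore the same structure, keeping tensors on CPU.
restored = torch.load(io.BytesIO(raw), map_location="cpu")
print(restored["context"].shape, restored["lat_h"])   # torch.Size([2, 8]) 58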
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.common.apis.dit \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9000
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
    """
    Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.

    Features:
    1. Does not modify the global random state.
    2. Each X is an uppercase letter or a digit (0-9).
    3. Combines time factors to ensure high randomness.

    For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
    """
    # Save the current random state (does not affect external randomness)
    original_state = random.getstate()
    try:
        # Define the character set (uppercase letters + digits)
        characters = string.ascii_uppercase + string.digits
        # Create an independent random instance
        local_random = random.Random(time.perf_counter_ns())
        # Generate 5 groups of 4-character random strings
        groups = []
        for _ in range(5):
            # Mix in a new time factor for each group
            time_mix = int(datetime.now().timestamp())
            local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
            groups.append("".join(local_random.choices(characters, k=4)))
        return "-".join(groups)
    finally:
        # Restore the original random state
        random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
message = {
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "./assets/inputs/imgs/img_0.jpg",
"save_video_path": "./output_lightx2v_wan_i2v_t02.mp4", # It is best to set it to an absolute path.
}
logger.info(f"message: {message}")
response = requests.post(url, json=message)
logger.info(f"response: {response.json()}")