Commit d71f936d authored by Yang Yong (雍洋), committed by GitHub

Remove vae args (#250)

parent cf04772a
@@ -18,7 +18,6 @@
    "adaptive_resize": true,
    "parallel": {
        "seq_p_size": 4,
-        "seq_p_attn_type": "ulysses",
-        "use_patch_vae": false
+        "seq_p_attn_type": "ulysses"
    }
}
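
For reference, after this change the `parallel` block of the config reduces to the following (only `use_patch_vae` is removed; the surrounding keys are unchanged):

"parallel": {
    "seq_p_size": 4,
    "seq_p_attn_type": "ulysses"
}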
import argparse
import json

import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel

from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    inputs: bytes
    kwargs: bytes

    def get(self, key, default=None):
        return getattr(self, key, default)


class DiTServiceStatus(BaseServiceStatus):
    pass


class DiTRunner:
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
        self.runner = self.runner_cls(config)
        self.runner.model = self.runner.load_transformer()

    def _run_dit(self, inputs, kwargs):
        self.runner.config.update(tensor_transporter.load_tensor(kwargs))
        self.runner.inputs = tensor_transporter.load_tensor(inputs)
        self.runner.init_scheduler()
        self.runner.model.scheduler.prepare(self.runner.inputs["image_encoder_output"])
        latents, _ = self.runner.run()
        self.runner.end_run()
        return latents


def run_dit(message: Message):
    try:
        global runner
        dit_output = runner._run_dit(message.inputs, message.kwargs)
        DiTServiceStatus.complete_task(message)
        return dit_output
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        DiTServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/dit/generate")
def v1_local_dit_generate(message: Message):
    try:
        task_id = DiTServiceStatus.start_task(message)
        dit_output = run_dit(message)
        output = tensor_transporter.prepare_tensor(dit_output)
        del dit_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/dit/generate/service_status")
async def get_service_status():
    return DiTServiceStatus.get_status_service()


@app.get("/v1/local/dit/generate/get_all_tasks")
async def get_all_tasks():
    return DiTServiceStatus.get_all_tasks()


@app.post("/v1/local/dit/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return DiTServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=9000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = DiTRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
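
The request/response protocol of this service is easiest to see from the client side. A minimal status-polling sketch, assuming the DiT server runs locally on its default port 9000 and the `requests` package is installed; the generate payload itself is omitted because `inputs`/`kwargs` must be bytes produced by `TensorTransporter`, whose wire encoding is not shown in this commit:

# Hypothetical status-polling client for the DiT service above (illustrative only).
import requests

BASE = "http://localhost:9000"  # default --port of this server

# Overall service status.
print(requests.get(f"{BASE}/v1/local/dit/generate/service_status").json())

# Status of a single task; TaskStatusMessage carries the task_id (value here is illustrative).
print(requests.post(f"{BASE}/v1/local/dit/generate/task_status", json={"task_id": "demo-task-001"}).json())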
import argparse
import json

import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel

from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    img: bytes

    def get(self, key, default=None):
        return getattr(self, key, default)


class ImageEncoderServiceStatus(BaseServiceStatus):
    pass


class ImageEncoderRunner:
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
        self.runner = self.runner_cls(config)
        self.runner.image_encoder = self.runner.load_image_encoder()

    def _run_image_encoder(self, img):
        img = image_transporter.load_image(img)
        return self.runner.run_image_encoder(img)


def run_image_encoder(message: Message):
    try:
        global runner
        image_encoder_out = runner._run_image_encoder(message.img)
        ImageEncoderServiceStatus.complete_task(message)
        return image_encoder_out
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        ImageEncoderServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/image_encoder/generate")
def v1_local_image_encoder_generate(message: Message):
    try:
        task_id = ImageEncoderServiceStatus.start_task(message)
        image_encoder_output = run_image_encoder(message)
        output = tensor_transporter.prepare_tensor(image_encoder_output)
        del image_encoder_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/image_encoder/generate/service_status")
async def get_service_status():
    return ImageEncoderServiceStatus.get_status_service()


@app.get("/v1/local/image_encoder/generate/get_all_tasks")
async def get_all_tasks():
    return ImageEncoderServiceStatus.get_all_tasks()


@app.post("/v1/local/image_encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return ImageEncoderServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=9003)
    args = parser.parse_args()
    logger.info(f"args: {args}")
    assert args.task == "i2v"

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = ImageEncoderRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
import argparse

import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel
from vllm import LLM, SamplingParams

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.service_utils import BaseServiceStatus, ProcessManager, TaskStatusMessage

# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()

sys_prompt = """
Transform the short prompt into a detailed video-generation caption using this structure:
Opening shot type (long/medium/close-up/extreme close-up/full shot)
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
Scene composition (background, environment, spatial relationships)
Lighting/atmosphere (natural/artificial, time of day, mood)
Camera motion (zooms, pans, static/handheld shots) if applicable.
Pattern Summary from Examples:
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
One case:
Short prompt: a person is playing football
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
Now expand this short prompt: [{}]. Please only output the final long prompt in English.
"""


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    prompt: str

    def get(self, key, default=None):
        return getattr(self, key, default)


class PromptEnhancerServiceStatus(BaseServiceStatus):
    pass


class PromptEnhancerRunner:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = self.get_model()
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=8192,
        )

    def get_model(self):
        model = LLM(model=self.model_path, trust_remote_code=True, dtype="bfloat16", gpu_memory_utilization=0.95, max_model_len=16384)
        return model

    def _run_prompt_enhancer(self, prompt):
        prompt = prompt.strip()
        prompt = sys_prompt.format(prompt)
        messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
        outputs = self.model.chat(
            messages=messages,
            sampling_params=self.sampling_params,
        )
        enhanced_prompt = outputs[0].outputs[0].text
        return enhanced_prompt.strip()


def run_prompt_enhancer(message: Message):
    try:
        global runner
        enhanced_prompt = runner._run_prompt_enhancer(message.prompt)
        assert enhanced_prompt is not None
        PromptEnhancerServiceStatus.complete_task(message)
        return enhanced_prompt
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        PromptEnhancerServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/prompt_enhancer/generate")
def v1_local_prompt_enhancer_generate(message: Message):
    try:
        task_id = PromptEnhancerServiceStatus.start_task(message)
        enhanced_prompt = run_prompt_enhancer(message)
        return {"task_id": task_id, "task_status": "completed", "output": enhanced_prompt, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/prompt_enhancer/generate/service_status")
async def get_service_status():
    return PromptEnhancerServiceStatus.get_status_service()


@app.get("/v1/local/prompt_enhancer/generate/get_all_tasks")
async def get_all_tasks():
    return PromptEnhancerServiceStatus.get_all_tasks()


@app.post("/v1/local/prompt_enhancer/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return PromptEnhancerServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--port", type=int, default=9001)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        runner = PromptEnhancerRunner(args.model_path)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False, workers=1)
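
Because this service takes and returns plain strings, a complete request fits in ordinary JSON. A minimal client sketch, assuming the server runs locally on its default port 9001 and `requests` is installed (the task_id value is illustrative):

# Hypothetical client for the prompt-enhancer service above (illustrative only).
import requests

payload = {"task_id": "enhance-001", "prompt": "a person is playing football"}
resp = requests.post("http://localhost:9001/v1/local/prompt_enhancer/generate", json=payload, timeout=300)
resp.raise_for_status()
body = resp.json()
if "error" in body:
    raise RuntimeError(body["error"])
print(body["output"])  # the expanded long prompt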
import argparse
import json
from typing import Optional

import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel

from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    text: str
    img: Optional[bytes] = None
    n_prompt: Optional[str] = None

    def get(self, key, default=None):
        return getattr(self, key, default)


class TextEncoderServiceStatus(BaseServiceStatus):
    pass


class TextEncoderRunner:
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
        self.runner = self.runner_cls(config)
        self.runner.text_encoders = self.runner.load_text_encoder()

    def _run_text_encoder(self, text, img, n_prompt):
        if img is not None:
            img = image_transporter.load_image(img)
        self.runner.config["negative_prompt"] = n_prompt
        text_encoder_output = self.runner.run_text_encoder(text, img)
        return text_encoder_output


def run_text_encoder(message: Message):
    try:
        global runner
        text_encoder_output = runner._run_text_encoder(message.text, message.img, message.n_prompt)
        TextEncoderServiceStatus.complete_task(message)
        return text_encoder_output
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        TextEncoderServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/text_encoders/generate")
def v1_local_text_encoder_generate(message: Message):
    try:
        task_id = TextEncoderServiceStatus.start_task(message)
        text_encoder_output = run_text_encoder(message)
        output = tensor_transporter.prepare_tensor(text_encoder_output)
        del text_encoder_output
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/text_encoders/generate/service_status")
async def get_service_status():
    return TextEncoderServiceStatus.get_status_service()


@app.get("/v1/local/text_encoders/generate/get_all_tasks")
async def get_all_tasks():
    return TextEncoderServiceStatus.get_all_tasks()


@app.post("/v1/local/text_encoders/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return TextEncoderServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=9002)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = TextEncoderRunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
import argparse
import json
from typing import Optional

import uvicorn
from fastapi import FastAPI
from loguru import logger
from pydantic import BaseModel

from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.service_utils import BaseServiceStatus, ImageTransporter, ProcessManager, TaskStatusMessage, TensorTransporter
from lightx2v.utils.set_config import set_config

tensor_transporter = TensorTransporter()
image_transporter = ImageTransporter()

# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    img: Optional[bytes] = None
    latents: Optional[bytes] = None

    def get(self, key, default=None):
        return getattr(self, key, default)


class VAEServiceStatus(BaseServiceStatus):
    pass


class VAERunner:
    def __init__(self, config):
        self.config = config
        self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
        self.runner = self.runner_cls(config)
        self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae()

    def _run_vae_encoder(self, img):
        img = image_transporter.load_image(img)
        vae_encoder_out, kwargs = self.runner.run_vae_encoder(img)
        return vae_encoder_out, kwargs

    def _run_vae_decoder(self, latents):
        latents = tensor_transporter.load_tensor(latents)
        images = self.runner.vae_decoder.decode(latents, generator=None, config=self.config)
        return images


def run_vae_encoder(message: Message):
    try:
        global runner
        vae_encoder_out, kwargs = runner._run_vae_encoder(message.img)
        VAEServiceStatus.complete_task(message)
        return vae_encoder_out, kwargs
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        VAEServiceStatus.record_failed_task(message, error=str(e))


def run_vae_decoder(message: Message):
    try:
        global runner
        images = runner._run_vae_decoder(message.latents)
        VAEServiceStatus.complete_task(message)
        return images
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        VAEServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/vae_model/encoder/generate")
def v1_local_vae_model_encoder_generate(message: Message):
    try:
        task_id = VAEServiceStatus.start_task(message)
        vae_encoder_out, kwargs = run_vae_encoder(message)
        output = tensor_transporter.prepare_tensor(vae_encoder_out)
        del vae_encoder_out
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": kwargs}
    except RuntimeError as e:
        return {"error": str(e)}


@app.post("/v1/local/vae_model/decoder/generate")
def v1_local_vae_model_decoder_generate(message: Message):
    try:
        task_id = VAEServiceStatus.start_task(message)
        vae_decode_out = run_vae_decoder(message)
        output = tensor_transporter.prepare_tensor(vae_decode_out)
        del vae_decode_out
        return {"task_id": task_id, "task_status": "completed", "output": output, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/vae_model/generate/service_status")
async def get_service_status():
    return VAEServiceStatus.get_status_service()


@app.get("/v1/local/vae_model/encoder/generate/service_status")
async def get_service_status():
    return VAEServiceStatus.get_status_service()


@app.get("/v1/local/vae_model/decoder/generate/service_status")
async def get_service_status():
    return VAEServiceStatus.get_status_service()


@app.get("/v1/local/vae_model/encoder/generate/get_all_tasks")
async def get_all_tasks():
    return VAEServiceStatus.get_all_tasks()


@app.get("/v1/local/vae_model/decoder/generate/get_all_tasks")
async def get_all_tasks():
    return VAEServiceStatus.get_all_tasks()


@app.post("/v1/local/vae_model/encoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return VAEServiceStatus.get_status_task_id(message.task_id)


@app.post("/v1/local/vae_model/decoder/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return VAEServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=9004)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = VAERunner(config)

    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
@@ -231,7 +231,7 @@ class DefaultRunner(BaseRunner):
    def run_vae_decoder(self, latents, generator):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
-        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        images = self.vae_decoder.decode(latents)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
......
@@ -11,7 +11,7 @@ from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
-from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_videos_grid
@@ -38,7 +38,7 @@ class HunyuanRunner(DefaultRunner):
        return text_encoders

    def load_vae(self):
-        vae_model = VideoEncoderKLCausal3DModel(self.config.model_path, dtype=torch.float16, device=self.init_device, config=self.config)
+        vae_model = HunyuanVAE(self.config.model_path, dtype=torch.float16, device=self.init_device, config=self.config)
        return vae_model, vae_model

    def init_scheduler(self):
......
@@ -159,7 +159,7 @@ class WanVaceRunner(WanRunner):
        self.config.target_shape = target_shape

    @ProfilingContext("Run VAE Decoder")
-    def run_vae_decoder(self, latents, generator):
+    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
@@ -169,7 +169,7 @@ class WanVaceRunner(WanRunner):
        if refs is not None:
            latents = latents[:, len(refs) :, :, :]
-        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        images = self.vae_decoder.decode(latents)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
......
@@ -302,7 +302,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        return DecoderOutput(sample=dec)

    @apply_forward_hook
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images/videos.
......
@@ -5,7 +5,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution


-class VideoEncoderKLCausal3DModel:
+class HunyuanVAE:
    def __init__(self, model_path, dtype, device, config):
        self.model_path = model_path
        self.dtype = dtype
@@ -32,20 +32,20 @@ class VideoEncoderKLCausal3DModel:
    def to_cuda(self):
        self.model = self.model.to("cuda")

-    def decode(self, latents, generator, config):
-        if config.cpu_offload:
+    def decode(self, latents):
+        if self.config.cpu_offload:
            self.to_cuda()
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
        self.model.enable_tiling()
-        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
+        image = self.model.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
-        if config.cpu_offload:
+        if self.config.cpu_offload:
            self.to_cpu()
        return image

-    def encode(self, x, config):
+    def encode(self, x):
        h = self.model.encoder(x)
        moments = self.model.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
@@ -54,4 +54,4 @@ class VideoEncoderKLCausal3DModel:

if __name__ == "__main__":
    model_path = ""
-    vae_model = VideoEncoderKLCausal3DModel(model_path, dtype=torch.float16, device=torch.device("cuda"))
+    vae_model = HunyuanVAE(model_path, dtype=torch.float16, device=torch.device("cuda"))
@@ -957,7 +957,7 @@ class WanVAE:
        return images

-    def decode(self, zs, **args):
+    def decode(self, zs):
        if self.cpu_offload:
            self.to_cuda()
......
@@ -993,7 +993,7 @@ class Wan2_2_VAE:
            self.to_cpu()
        return out

-    def decode(self, zs, **args):
+    def decode(self, zs):
        if self.cpu_offload:
            self.to_cuda()
        images = self.model.decode(zs.unsqueeze(0), self.scale, offload_cache=self.offload_cache if self.cpu_offload else False).float().clamp_(-1, 1)
......
@@ -22,7 +22,7 @@ class WanVAE_tiny(nn.Module):
    @peak_memory_decorator
    @torch.no_grad()
-    def decode(self, latents, generator=None, return_dict=None, config=None):
+    def decode(self, latents):
        latents = latents.unsqueeze(0)
        n, c, t, h, w = latents.shape
        # low-memory, set parallel=True for faster + higher memory
......
import os

import torch

from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from lightx2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer


class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype
        self.device = device
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)
        self.model.eval()
        trt_decoder = trt_vae_infer.HyVaeTrtModelInfer(engine_path=os.path.join(self.vae_path, "vae_decoder.engine"))
        self.model.decoder = trt_decoder

    def decode(self, latents, generator):
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=self.device)
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        return image


if __name__ == "__main__":
    model_path = ""
    vae_model = VideoEncoderKLCausal3DModel(model_path, dtype=torch.float16, device=torch.device("cuda"))
import os
from pathlib import Path
from subprocess import Popen

import numpy as np
import tensorrt as trt
import torch
import torch.nn as nn
from cuda import cudart
from loguru import logger

from lightx2v.common.backend_infer.trt import common

TRT_LOGGER = trt.Logger(trt.Logger.INFO)


class HyVaeTrtModelInfer(nn.Module):
    """
    Implements inference for the TensorRT engine.
    """

    def __init__(self, engine_path):
        """
        :param engine_path: The path to the serialized engine to load from disk.
        """
        # Load TRT engine
        if not Path(engine_path).exists():
            # dir_name = str(Path(engine_path).parents)
            # onnx_path = self.export_to_onnx(decoder, dir_name)
            # self.convert_to_trt_engine(onnx_path, engine_path)
            raise FileNotFoundError(f"VAE tensorrt engine `{str(engine_path)}` not exists.")
        self.logger = trt.Logger(trt.Logger.ERROR)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
        assert self.engine
        self.context = self.engine.create_execution_context()
        assert self.context
        logger.info(f"Loaded VAE tensorrt engine from `{engine_path}`")

    def alloc(self, shape_dict):
        """
        Setup I/O bindings
        """
        self.inputs = []
        self.outputs = []
        self.allocations = []
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            is_input = False
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                is_input = True
            dtype = self.engine.get_tensor_dtype(name)
            # shape = self.engine.get_tensor_shape(name)
            shape = shape_dict[name]
            if is_input:
                self.context.set_input_shape(name, shape)
                self.batch_size = shape[0]
            size = np.dtype(trt.nptype(dtype)).itemsize
            for s in shape:
                size *= s
            allocation = common.cuda_call(cudart.cudaMalloc(size))
            binding = {
                "index": i,
                "name": name,
                "dtype": np.dtype(trt.nptype(dtype)),
                "shape": list(shape),
                "allocation": allocation,
            }
            self.allocations.append(allocation)
            if is_input:
                self.inputs.append(binding)
            else:
                self.outputs.append(binding)

        assert self.batch_size > 0
        assert len(self.inputs) > 0
        assert len(self.outputs) > 0
        assert len(self.allocations) > 0

    def input_spec(self):
        """
        Get the specs for the input tensor of the network. Useful to prepare memory allocations.
        :return: Two items, the shape of the input tensor and its (numpy) datatype.
        """
        return self.inputs[0]["shape"], self.inputs[0]["dtype"]

    def output_spec(self):
        """
        Get the specs for the output tensor of the network. Useful to prepare memory allocations.
        :return: Two items, the shape of the output tensor and its (numpy) datatype.
        """
        return self.outputs[0]["shape"], self.outputs[0]["dtype"]

    def __call__(self, batch, top=1):
        """
        Execute inference
        """
        # Prepare the output data
        device = batch.device
        dtype = batch.dtype
        batch = batch.cpu().numpy()

        def get_output_shape(shp):
            b, c, t, h, w = shp
            out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
            return out

        shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
        self.alloc(shp_dict)
        output = np.zeros(*self.output_spec())

        # Process I/O and execute the network
        common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
        self.context.execute_v2(self.allocations)
        common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
        output = torch.from_numpy(output).to(device).type(dtype)
        return output

    @staticmethod
    def export_to_onnx(decoder: torch.nn.Module, model_dir):
        logger.info("Start to do VAE onnx exporting.")
        device = next(decoder.parameters())[0].device
        example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
        out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
        torch.onnx.export(
            decoder.eval().half(),
            example_inp.half(),
            out_path,
            input_names=["inp"],
            output_names=["out"],
            opset_version=14,
            dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
        )
        # onnx_ori = onnx.load(out_path)
        os.system(f"onnxsim {out_path} {out_path}")
        # onnx_opt, check = simplify(onnx_ori)
        # assert check, f"Simplified ONNX model({out_path}) could not be validated."
        # onnx.save(onnx_opt, out_path)
        logger.info("Finish VAE onnx exporting.")
        return out_path

    @staticmethod
    def convert_to_trt_engine(onnx_path, engine_path):
        logger.info("Start to convert VAE ONNX to tensorrt engine.")
        cmd = (
            "trtexec "
            f"--onnx={onnx_path} "
            f"--saveEngine={engine_path} "
            "--allowWeightStreaming "
            "--stronglyTyped "
            "--fp16 "
            "--weightStreamingBudget=100 "
            "--minShapes=inp:1x16x9x18x16 "
            "--optShapes=inp:1x16x17x32x16 "
            "--maxShapes=inp:1x16x17x32x32 "
        )
        p = Popen(cmd, shell=True)
        p.wait()
        if not Path(engine_path).exists():
            raise RuntimeError(f"Convert vae onnx({onnx_path}) to tensorrt engine failed.")
        logger.info("Finish VAE tensorrt converting.")
        return engine_path
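
A minimal usage sketch for the wrapper above, assuming an engine has already been built with convert_to_trt_engine and saved to an illustrative path; the latent shape follows the optShapes/maxShapes used there:

# Hypothetical usage of the TensorRT VAE decoder wrapper (sketch; engine path is illustrative).
import torch
from lightx2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer

decoder = trt_vae_infer.HyVaeTrtModelInfer(engine_path="/path/to/vae_decoder.engine")
latents = torch.rand(1, 16, 17, 32, 32, dtype=torch.float16, device="cuda")
frames = decoder(latents)  # (1, 3, 65, 256, 256): get_output_shape gives (b, 3, 4*(t-1)+1, h*8, w*8)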