Unverified Commit 3aab9893 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

update vae (#410)

parent b32321e4
......@@ -4,11 +4,13 @@ Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import os
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from tqdm.auto import tqdm
# Lightweight immutable record with two positional fields: `frame` and `memory`.
DecoderResult = namedtuple("DecoderResult", ["frame", "memory"])
......@@ -226,7 +228,16 @@ class TAEHV(nn.Module):
conv(n_f[3], self.image_channels * self.patch_size**2),
)
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))
ext = os.path.splitext(checkpoint_path)[1].lower()
if ext == ".pth":
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
elif ext == ".safetensors":
state_dict = load_file(checkpoint_path, device="cpu")
else:
raise ValueError(f"Unsupported checkpoint format: {ext}. Supported formats: .pth, .safetensors")
self.load_state_dict(self.patch_tgrow_layers(state_dict))
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
......
"""
LightX2V Setup Script
Minimal installation for VAE models only
"""
import os
from setuptools import find_packages, setup
# Read the README file
def read_readme():
    """Return the text of the README.md that sits next to this script.

    Returns:
        str: The README contents decoded as UTF-8, or an empty string when
        the file does not exist.
    """
    readme_path = os.path.join(os.path.dirname(__file__), "README.md")
    # Open directly rather than checking for existence first, so the file
    # cannot vanish between the check and the read.
    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return ""
# Core dependencies for VAE models.
# safetensors and tqdm are included because the tiny-autoencoder VAE module
# imports `safetensors.torch` and `tqdm.auto` at module level; without them a
# minimal install fails on import.
# NOTE(review): confirm that module ships inside the `lightx2v` package.
vae_dependencies = [
    "torch>=2.0.0",
    "numpy>=1.20.0",
    "einops>=0.6.0",
    "loguru>=0.6.0",
    "safetensors",
    "tqdm",
]
# Requirements for the complete LightX2V install (exposed as the "full" extra).
# Entries are unpinned; the grouping comments are for readability only and the
# original ordering is preserved.
full_dependencies = [
    # Build helpers
    "packaging", "ninja",
    # Deep-learning stack
    "torch", "torchvision",
    "diffusers", "transformers", "tokenizers",
    "tqdm",
    "accelerate", "safetensors",
    # Image / video I/O and array handling
    "opencv-python", "numpy",
    "imageio", "imageio-ffmpeg",
    # Utilities
    "einops", "loguru", "ftfy",
    # UI, web serving and HTTP
    "gradio", "aiohttp", "pydantic",
    "fastapi", "uvicorn", "requests",
    # Video decoding
    "decord",
]
# Package metadata and install configuration. The default install pulls only
# the VAE requirements; the "full" extra adds the complete dependency set.
setup(
    # --- Identity ---
    name="lightx2v",
    version="1.0.0",
    url="https://github.com/ModelTC/LightX2V",
    author="LightX2V Team",
    author_email="",
    # --- Descriptions ---
    description="LightX2V: High-performance video generation models with optimized VAE",
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    # --- Package discovery and packaging flags ---
    packages=find_packages(include=["lightx2v", "lightx2v.*"]),
    include_package_data=True,
    zip_safe=False,
    # --- Requirements ---
    python_requires=">=3.8",
    install_requires=vae_dependencies,
    extras_require={
        "vae": vae_dependencies,
        "full": full_dependencies,
    },
    # --- Trove classifiers ---
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
......@@ -390,7 +390,7 @@ def quantize_model(
# Check if key matches target modules
parts = key.split(".")
if comfyui_mode and key in comfyui_keys:
if comfyui_mode and (comfyui_keys is not None and key in comfyui_keys):
pass
elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if adapter_keys is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment