"...resnet50_tensorflow.git" did not exist on "27fb855b027ead16d2616dcb59c67409a2176b7f"
Unverified Commit 3aab9893 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

update vae (#410)

parent b32321e4
...@@ -4,11 +4,13 @@ Tiny AutoEncoder for Hunyuan Video ...@@ -4,11 +4,13 @@ Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space) (DNN for encoding / decoding videos to Hunyuan Video's latent space)
""" """
import os
from collections import namedtuple from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from safetensors.torch import load_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
...@@ -226,7 +228,16 @@ class TAEHV(nn.Module): ...@@ -226,7 +228,16 @@ class TAEHV(nn.Module):
conv(n_f[3], self.image_channels * self.patch_size**2), conv(n_f[3], self.image_channels * self.patch_size**2),
) )
if checkpoint_path is not None: if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True))) ext = os.path.splitext(checkpoint_path)[1].lower()
if ext == ".pth":
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
elif ext == ".safetensors":
state_dict = load_file(checkpoint_path, device="cpu")
else:
raise ValueError(f"Unsupported checkpoint format: {ext}. Supported formats: .pth, .safetensors")
self.load_state_dict(self.patch_tgrow_layers(state_dict))
def patch_tgrow_layers(self, sd): def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed. """Patch TGrow layers to use a smaller kernel if needed.
......
"""
LightX2V Setup Script
Minimal installation for VAE models only
"""
import os
from setuptools import find_packages, setup
# Read the README file
def read_readme():
readme_path = os.path.join(os.path.dirname(__file__), "README.md")
if os.path.exists(readme_path):
with open(readme_path, "r", encoding="utf-8") as f:
return f.read()
return ""
# Core dependencies for VAE models
vae_dependencies = [
"torch>=2.0.0",
"numpy>=1.20.0",
"einops>=0.6.0",
"loguru>=0.6.0",
]
# Full dependencies for complete LightX2V
full_dependencies = [
"packaging",
"ninja",
"torch",
"torchvision",
"diffusers",
"transformers",
"tokenizers",
"tqdm",
"accelerate",
"safetensors",
"opencv-python",
"numpy",
"imageio",
"imageio-ffmpeg",
"einops",
"loguru",
"ftfy",
"gradio",
"aiohttp",
"pydantic",
"fastapi",
"uvicorn",
"requests",
"decord",
]
setup(
name="lightx2v",
version="1.0.0",
author="LightX2V Team",
author_email="",
description="LightX2V: High-performance video generation models with optimized VAE",
long_description=read_readme(),
long_description_content_type="text/markdown",
url="https://github.com/ModelTC/LightX2V",
packages=find_packages(include=["lightx2v", "lightx2v.*"]),
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
python_requires=">=3.8",
install_requires=vae_dependencies,
extras_require={
"full": full_dependencies,
"vae": vae_dependencies,
},
include_package_data=True,
zip_safe=False,
)
...@@ -390,7 +390,7 @@ def quantize_model( ...@@ -390,7 +390,7 @@ def quantize_model(
# Check if key matches target modules # Check if key matches target modules
parts = key.split(".") parts = key.split(".")
if comfyui_mode and key in comfyui_keys: if comfyui_mode and (comfyui_keys is not None and key in comfyui_keys):
pass pass
elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys: elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if adapter_keys is None: if adapter_keys is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment