Commit 420fec7f authored by helloyongyang's avatar helloyongyang
Browse files

Support save_naive_quant and load quantization weight

parent a81ad1e5
import os
import sys
import torch
import time
import glob
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
......@@ -16,6 +16,8 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger
class WanModel:
......@@ -29,6 +31,11 @@ class WanModel:
self.device = device
self._init_infer_class()
self._init_weights()
if GET_RUNNING_FLAG() == "save_naive_quant":
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
self.save_weights(self.config.naive_quant_path)
sys.exit(0)
self._init_infer()
self.current_lora = None
......@@ -74,9 +81,19 @@ class WanModel:
weight_dict.update(file_weights)
return weight_dict
def _load_ckpt_quant_model(self):
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
logger.info(f"Loading quant model from {self.config.naive_quant_path}")
quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
return weight_dict
def _init_weights(self, weight_dict=None):
if weight_dict is None:
self.original_weight_dict = self._load_ckpt()
if GET_RUNNING_FLAG() == "save_naive_quant" or self.config["mm_config"].get("weight_auto_quant", False):
self.original_weight_dict = self._load_ckpt()
else:
self.original_weight_dict = self._load_ckpt_quant_model()
else:
self.original_weight_dict = weight_dict
# init weights
......@@ -93,6 +110,28 @@ class WanModel:
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
def save_weights(self, save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)
pre_state_dict = self.pre_weight.state_dict()
logger.info(pre_state_dict.keys())
post_state_dict = self.post_weight.state_dict()
logger.info(post_state_dict.keys())
transformer_state_dict = self.transformer_weights.state_dict()
logger.info(transformer_state_dict.keys())
save_dict = {}
save_dict.update(pre_state_dict)
save_dict.update(post_state_dict)
save_dict.update(transformer_state_dict)
save_path = os.path.join(save_path, "quant_weights.pth")
torch.save(save_dict, save_path)
logger.info(f"Save weights to {save_path}")
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
......
......@@ -12,3 +12,9 @@ def CHECK_ENABLE_PROFILING_DEBUG():
def CHECK_ENABLE_GRAPH_MODE():
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
return ENABLE_GRAPH_MODE
@lru_cache(maxsize=None)
def GET_RUNNING_FLAG():
RUNNING_FLAG = os.getenv("RUNNING_FLAG", "infer")
return RUNNING_FLAG
......@@ -18,7 +18,7 @@ def get_default_config():
"use_bfloat16": True,
"lora_path": None,
"strength_model": 1.0,
"mm_config": None,
"mm_config": {},
}
return default_config
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v_save_quant.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_i2v.mp4
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v_save_quant.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_hy_t2v.mp4
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_save_quant.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=2
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
# =========================
# save quantization weight
# =========================
export RUNNING_FLAG=save_naive_quant
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
sleep 2
# =========================
# load quantization weight and inference
# =========================
export RUNNING_FLAG=infer
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment