Commit 6bd320af authored by gushiqiao, committed by GitHub

Merge pull request #70 from ModelTC/dev_FIX

Fixed the accuracy fluctuation bug
parents e9e33065 774ccfe7
@@ -25,7 +25,7 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
@@ -36,5 +36,5 @@ python -m lightx2v.infer \
--config_json ${lightx2v_path}/configs/wan_t2v_enhancer.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--use_prompt_enhancer \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
@@ -25,7 +25,7 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
@@ -35,5 +35,5 @@ python -m lightx2v.infer \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_sparge.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
@@ -24,7 +24,7 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
@@ -34,5 +34,5 @@ python -m lightx2v.infer \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/wan_t2v_Tea.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4
@@ -26,6 +26,7 @@ export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
echo "=========================================="
echo "启动分布式推理API服务器"
......
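The DTYPE variable exported above is picked up at runtime; below is a minimal sketch of how such an environment variable can be mapped to a torch dtype. The mapping table and default are assumptions for illustration, not lightx2v's actual lookup logic.

import os
import torch

# Hypothetical mapping from the DTYPE env var to a torch dtype;
# lightx2v's actual resolution logic may differ.
_DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}
dtype = _DTYPE_MAP.get(os.environ.get("DTYPE", "BF16"), torch.bfloat16)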
@@ -31,6 +31,7 @@ export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
# Start multiple servers
python -m lightx2v.api_multi_servers \
......
@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.api_server \
--model_cls wan2.1 \
......
File mode changed from 100644 to 100755
@@ -339,7 +339,6 @@ def quantize_model(
weights,
w_bit=8,
target_keys=["attn", "ffn"],
min_params=1e6,
key_idx=2,
ignore_key=None,
dtype=torch.int8,
@@ -351,7 +350,6 @@ def quantize_model(
weights: Model state dictionary
w_bit: Quantization bit width
target_keys: List of module names to quantize
min_params: Minimum parameter count to process tensor
Returns:
Modified state dictionary with quantized weights and scales
@@ -371,7 +369,7 @@ def quantize_model(
tensor = weights[key]
# Skip non-tensors and non-2D tensors
if not isinstance(tensor, torch.Tensor) or tensor.numel() < min_params or tensor.dim() != 2:
if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
continue
# Check if key matches target modules
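With the min_params guard removed, every 2-D weight tensor whose key matches a target module is quantized regardless of size. Below is a minimal sketch of symmetric per-channel int8 quantization of the kind this pass performs; the exact scale computation in converter.py may differ.

import torch

def quantize_int8_per_channel(w: torch.Tensor):
    # One scale per output channel (row) of a 2-D weight matrix.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(1)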
@@ -442,7 +440,6 @@ def convert_weights(args):
converted_weights,
w_bit=args.bits,
target_keys=args.target_keys,
min_params=args.min_params,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
dtype=args.dtype,
@@ -522,7 +519,7 @@ def convert_weights(args):
def copy_non_weight_files(source_dir, target_dir):
ignore_extensions = [".pth", ".pt", ".safetensors"]
ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"]
logger.info("Start copying non-weight files and subdirectories...")
@@ -575,12 +572,6 @@ def main():
# Quantization
parser.add_argument("--quantized", action="store_true")
parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width")
parser.add_argument(
"--min_params",
type=int,
default=1000000,
help="Minimum parameters to consider for quantization",
)
parser.add_argument(
"--device",
type=str,
@@ -595,47 +586,48 @@
)
args = parser.parse_args()
if args.dtype == "torch.int8":
args.dtype = torch.int8
elif args.dtype == "torch.float8_e4m3fn":
args.dtype = torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
model_type_keys_map = {
"wan_dit": {
"key_idx": 2,
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": None,
},
"hunyuan_dit": {
"key_idx": 2,
"target_keys": [
"img_mod",
"img_attn_qkv",
"img_attn_proj",
"img_mlp",
"txt_mod",
"txt_attn_qkv",
"txt_attn_proj",
"txt_mlp",
"linear1",
"linear2",
"modulation",
],
"ignore_key": None,
},
"wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
"wan_clip": {
"key_idx": 3,
"target_keys": ["attn", "mlp"],
"ignore_key": "textual",
},
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
if args.quantized:
if args.dtype == "torch.int8":
args.dtype = torch.int8
elif args.dtype == "torch.float8_e4m3fn":
args.dtype = torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
model_type_keys_map = {
"wan_dit": {
"key_idx": 2,
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": None,
},
"hunyuan_dit": {
"key_idx": 2,
"target_keys": [
"img_mod",
"img_attn_qkv",
"img_attn_proj",
"img_mlp",
"txt_mod",
"txt_attn_qkv",
"txt_attn_proj",
"txt_mlp",
"linear1",
"linear2",
"modulation",
],
"ignore_key": None,
},
"wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
"wan_clip": {
"key_idx": 3,
"target_keys": ["attn", "mlp"],
"ignore_key": "textual",
},
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
......
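The key_idx entries in model_type_keys_map index into the dot-separated parameter name. Below is a hypothetical illustration of the matching this implies; the real check lives in quantize_model and may differ.

# Example wan_dit key; with key_idx=2 the module token is "self_attn".
key = "blocks.0.self_attn.q.weight"
target_keys = ["self_attn", "cross_attn", "ffn"]
key_idx = 2
parts = key.split(".")
should_quantize = len(parts) > key_idx and parts[key_idx] in target_keys  # True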
@@ -33,7 +33,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
......
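After conversion, the .safetensors output can be sanity-checked with the safetensors library. A quick sketch, with the path and file name assumed from the int8 example above:

from safetensors.torch import load_file

state = load_file("/Path/To/output/wan_int8.safetensors")  # assumed output path
for name, tensor in list(state.items())[:5]:
    print(name, tensor.dtype, tuple(tensor.shape))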
@@ -33,7 +33,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--model_type wan_dit
@@ -44,7 +44,7 @@ python converter.py \
--quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type wan_dit
@@ -57,7 +57,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--model_type hunyuan_dit
@@ -68,7 +68,7 @@ python converter.py \
--quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .pth\
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--model_type hunyuan_dit
......
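The fp8 examples rely on torch.float8_e4m3fn, which is only available in recent PyTorch builds (2.1+). A quick availability check before running the converter:

import torch

# Fails fast if this PyTorch build does not ship the fp8 dtype.
assert hasattr(torch, "float8_e4m3fn"), "this PyTorch build lacks float8_e4m3fn"
x = torch.randn(2, 2).to(torch.float8_e4m3fn)
print(x.dtype)  # torch.float8_e4m3fn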