"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ecc784082600b830649f2a4c0b2d8d7e9d8755ed"
Commit 6bd320af authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Merge pull request #70 from ModelTC/dev_FIX

Fixed the accuracy fluctuation bug
parents e9e33065 774ccfe7
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
export TOKENIZERS_PARALLELISM=false export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
...@@ -36,5 +36,5 @@ python -m lightx2v.infer \ ...@@ -36,5 +36,5 @@ python -m lightx2v.infer \
--config_json ${lightx2v_path}/configs/wan_t2v_enhancer.json \ --config_json ${lightx2v_path}/configs/wan_t2v_enhancer.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--use_prompt_enhancer \ --use_prompt_enhancer \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
export TOKENIZERS_PARALLELISM=false export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
...@@ -35,5 +35,5 @@ python -m lightx2v.infer \ ...@@ -35,5 +35,5 @@ python -m lightx2v.infer \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_sparge.json \ --config_json ${lightx2v_path}/configs/wan_t2v_sparge.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
...@@ -24,7 +24,7 @@ fi ...@@ -24,7 +24,7 @@ fi
export TOKENIZERS_PARALLELISM=false export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
...@@ -34,5 +34,5 @@ python -m lightx2v.infer \ ...@@ -34,5 +34,5 @@ python -m lightx2v.infer \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/wan_t2v_Tea.json \ --config_json ${lightx2v_path}/configs/caching/wan_t2v_Tea.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \ --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4
...@@ -26,6 +26,7 @@ export TOKENIZERS_PARALLELISM=false ...@@ -26,6 +26,7 @@ export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
echo "==========================================" echo "=========================================="
echo "启动分布式推理API服务器" echo "启动分布式推理API服务器"
......
...@@ -31,6 +31,7 @@ export TOKENIZERS_PARALLELISM=false ...@@ -31,6 +31,7 @@ export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
# Start multiple servers # Start multiple servers
python -m lightx2v.api_multi_servers \ python -m lightx2v.api_multi_servers \
......
...@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH ...@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.api_server \ python -m lightx2v.api_server \
--model_cls wan2.1 \ --model_cls wan2.1 \
......
File mode changed from 100644 to 100755
...@@ -339,7 +339,6 @@ def quantize_model( ...@@ -339,7 +339,6 @@ def quantize_model(
weights, weights,
w_bit=8, w_bit=8,
target_keys=["attn", "ffn"], target_keys=["attn", "ffn"],
min_params=1e6,
key_idx=2, key_idx=2,
ignore_key=None, ignore_key=None,
dtype=torch.int8, dtype=torch.int8,
...@@ -351,7 +350,6 @@ def quantize_model( ...@@ -351,7 +350,6 @@ def quantize_model(
weights: Model state dictionary weights: Model state dictionary
w_bit: Quantization bit width w_bit: Quantization bit width
target_keys: List of module names to quantize target_keys: List of module names to quantize
min_params: Minimum parameter count to process tensor
Returns: Returns:
Modified state dictionary with quantized weights and scales Modified state dictionary with quantized weights and scales
...@@ -371,7 +369,7 @@ def quantize_model( ...@@ -371,7 +369,7 @@ def quantize_model(
tensor = weights[key] tensor = weights[key]
# Skip non-tensors, small tensors, and non-2D tensors # Skip non-tensors, small tensors, and non-2D tensors
if not isinstance(tensor, torch.Tensor) or tensor.numel() < min_params or tensor.dim() != 2: if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
continue continue
# Check if key matches target modules # Check if key matches target modules
...@@ -442,7 +440,6 @@ def convert_weights(args): ...@@ -442,7 +440,6 @@ def convert_weights(args):
converted_weights, converted_weights,
w_bit=args.bits, w_bit=args.bits,
target_keys=args.target_keys, target_keys=args.target_keys,
min_params=args.min_params,
key_idx=args.key_idx, key_idx=args.key_idx,
ignore_key=args.ignore_key, ignore_key=args.ignore_key,
dtype=args.dtype, dtype=args.dtype,
...@@ -522,7 +519,7 @@ def convert_weights(args): ...@@ -522,7 +519,7 @@ def convert_weights(args):
def copy_non_weight_files(source_dir, target_dir): def copy_non_weight_files(source_dir, target_dir):
ignore_extensions = [".pth", ".pt", ".safetensors"] ignore_extensions = [".pth", ".pt", ".safetensors", ".index.json"]
logger.info(f"Start copying non-weighted files and subdirectories...") logger.info(f"Start copying non-weighted files and subdirectories...")
...@@ -575,12 +572,6 @@ def main(): ...@@ -575,12 +572,6 @@ def main():
# Quantization # Quantization
parser.add_argument("--quantized", action="store_true") parser.add_argument("--quantized", action="store_true")
parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width") parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width")
parser.add_argument(
"--min_params",
type=int,
default=1000000,
help="Minimum parameters to consider for quantization",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
...@@ -595,47 +586,48 @@ def main(): ...@@ -595,47 +586,48 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
if args.dtype == "torch.int8": if args.quantized:
args.dtype = torch.int8 if args.dtype == "torch.int8":
elif args.dtype == "torch.float8_e4m3fn": args.dtype = torch.int8
args.dtype = torch.float8_e4m3fn elif args.dtype == "torch.float8_e4m3fn":
else: args.dtype = torch.float8_e4m3fn
raise ValueError(f"Not support dtype :{args.dtype}") else:
raise ValueError(f"Not support dtype :{args.dtype}")
model_type_keys_map = {
"wan_dit": { model_type_keys_map = {
"key_idx": 2, "wan_dit": {
"target_keys": ["self_attn", "cross_attn", "ffn"], "key_idx": 2,
"ignore_key": None, "target_keys": ["self_attn", "cross_attn", "ffn"],
}, "ignore_key": None,
"hunyuan_dit": { },
"key_idx": 2, "hunyuan_dit": {
"target_keys": [ "key_idx": 2,
"img_mod", "target_keys": [
"img_attn_qkv", "img_mod",
"img_attn_proj", "img_attn_qkv",
"img_mlp", "img_attn_proj",
"txt_mod", "img_mlp",
"txt_attn_qkv", "txt_mod",
"txt_attn_proj", "txt_attn_qkv",
"txt_mlp", "txt_attn_proj",
"linear1", "txt_mlp",
"linear2", "linear1",
"modulation", "linear2",
], "modulation",
"ignore_key": None, ],
}, "ignore_key": None,
"wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None}, },
"wan_clip": { "wan_t5": {"key_idx": 2, "target_keys": ["attn", "ffn"], "ignore_key": None},
"key_idx": 3, "wan_clip": {
"target_keys": ["attn", "mlp"], "key_idx": 3,
"ignore_key": "textual", "target_keys": ["attn", "mlp"],
}, "ignore_key": "textual",
} },
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.key_idx = model_type_keys_map[args.model_type]["key_idx"] args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"] args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
if os.path.isfile(args.output): if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file") raise ValueError("Output path must be a directory, not a file")
......
...@@ -33,7 +33,7 @@ python converter.py \ ...@@ -33,7 +33,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \ --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name wan_int8 \ --output_name wan_int8 \
--dtype torch.int8 \ --dtype torch.int8 \
--model_type wan_dit --model_type wan_dit
...@@ -44,7 +44,7 @@ python converter.py \ ...@@ -44,7 +44,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \ --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name wan_fp8 \ --output_name wan_fp8 \
--dtype torch.float8_e4m3fn \ --dtype torch.float8_e4m3fn \
--model_type wan_dit --model_type wan_dit
...@@ -57,7 +57,7 @@ python converter.py \ ...@@ -57,7 +57,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \ --source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name hunyuan_int8 \ --output_name hunyuan_int8 \
--dtype torch.int8 \ --dtype torch.int8 \
--model_type hunyuan_dit --model_type hunyuan_dit
...@@ -68,7 +68,7 @@ python converter.py \ ...@@ -68,7 +68,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \ --source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name hunyuan_fp8 \ --output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \ --dtype torch.float8_e4m3fn \
--model_type hunyuan_dit --model_type hunyuan_dit
......
...@@ -33,7 +33,7 @@ python converter.py \ ...@@ -33,7 +33,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \ --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name wan_int8 \ --output_name wan_int8 \
--dtype torch.int8 \ --dtype torch.int8 \
--model_type wan_dit --model_type wan_dit
...@@ -44,7 +44,7 @@ python converter.py \ ...@@ -44,7 +44,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \ --source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name wan_fp8 \ --output_name wan_fp8 \
--dtype torch.float8_e4m3fn \ --dtype torch.float8_e4m3fn \
--model_type wan_dit --model_type wan_dit
...@@ -57,7 +57,7 @@ python converter.py \ ...@@ -57,7 +57,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \ --source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name hunyuan_int8 \ --output_name hunyuan_int8 \
--dtype torch.int8 \ --dtype torch.int8 \
--model_type hunyuan_dit --model_type hunyuan_dit
...@@ -68,7 +68,7 @@ python converter.py \ ...@@ -68,7 +68,7 @@ python converter.py \
--quantized \ --quantized \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \ --source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \ --output /Path/To/output \
--output_ext .pth\ --output_ext .safetensors \
--output_name hunyuan_fp8 \ --output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \ --dtype torch.float8_e4m3fn \
--model_type hunyuan_dit --model_type hunyuan_dit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment