quant_adapter.py
import argparse
import sys
from pathlib import Path

import safetensors
import torch
from safetensors.torch import save_file

# Make the project root (three levels up) importable for the imports below
sys.path.append(str(Path(__file__).parent.parent.parent))

from lightx2v.utils.quant_utils import FloatQuantizer
from tools.convert.quant import *
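# The wildcard import presumably supplies QuantWeightMxFP4 / QuantWeightMxFP6 /
# QuantWeightMxFP8, referenced in the commented-out MX path inside main() below.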


def main():
    # Locate the directory containing this script
    script_dir = Path(__file__).parent
    project_root = script_dir.parent.parent

    parser = argparse.ArgumentParser(description="Quantize audio adapter model to FP8")
    parser.add_argument(
        "--model_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill" / "audio_adapter_model.safetensors"),
        help="Path to input model file",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill-fp8" / "audio_adapter_model_fp8.safetensors"),
        help="Path to output quantized model file",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    output_path = Path(args.output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)

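    # Load every tensor from the input checkpoint onto CPU; safe_open reads
    # lazily, so each tensor is materialized only by get_tensor.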
    state_dict = {}
    with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    new_state_dict = {}

    for key in state_dict.keys():
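        # Only cross-attention projection weights get the FP8 treatment: the
        # filter is assumed to match keys like "ca*.to_q/to_k/to_v/to_out...weight"
        # (inferred from the prefix/substring checks, not from the checkpoint itself).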
        if key.startswith("ca") and ".to" in key and "weight" in key:
            print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")

            ## FP8 (e4m3, per-channel) quantization path
            weight = state_dict[key].to(torch.float32).cuda()
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
            weight = weight.to(torch.float8_e4m3fn)
            weight_scale = weight_scale.to(torch.float32)
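
            # Rough sketch of what per-channel e4m3 quantization is assumed to
            # compute (an inference about FloatQuantizer's internals, not verified):
            #   scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
            #   q = (w / scale).to(torch.float8_e4m3fn)  # dequantize as q.float() * scale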

            ## Alternative: QuantWeightMxFP4 / QuantWeightMxFP6 / QuantWeightMxFP8 for mxfp4 / mxfp6 / mxfp8
            # weight = state_dict[key].to(torch.bfloat16).cuda()
            # quantizer = QuantWeightMxFP4(weight)
            # weight, weight_scale, _ = quantizer.weight_quant_func(weight)
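            # Note: per the OCP Microscaling (MX) spec, these formats share one
            # power-of-two scale per 32-element block, a finer granularity than
            # the single FP32 scale per output channel used in the FP8 path above.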

            new_state_dict[key] = weight.cpu()
            new_state_dict[key + "_scale"] = weight_scale.cpu()
        else:
            # Convert weights that don't match the filter to BF16
            print(f"Converting {key} to BF16, dtype: {state_dict[key].dtype}")
            new_state_dict[key] = state_dict[key].to(torch.bfloat16)

    save_file(new_state_dict, str(output_path))
    print(f"Quantized model saved to: {output_path}")


if __name__ == "__main__":
    main()
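

# Example invocation (paths are the argparse defaults; the script is assumed to
# live under tools/convert/, as the sys.path setup and imports imply):
#   python tools/convert/quant_adapter.py \
#       --model_path models/SekoTalk-Distill/audio_adapter_model.safetensors \
#       --output_path models/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors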