# quant_adapter.py
import argparse
import sys
from pathlib import Path

import safetensors
import torch
from safetensors.torch import save_file

sys.path.append(str(Path(__file__).parent.parent.parent))

from lightx2v.utils.quant_utils import FloatQuantizer
from tools.convert.quant import *


def main():
    # Resolve the directory this script lives in
    script_dir = Path(__file__).parent
    project_root = script_dir.parent.parent

    parser = argparse.ArgumentParser(description="Quantize audio adapter model to FP8")
    parser.add_argument(
        "--model_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill" / "audio_adapter_model.safetensors"),
        help="Path to input model file",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill-fp8" / "audio_adapter_model_fp8.safetensors"),
        help="Path to output quantized model file",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    output_path = Path(args.output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    state_dict = {}
    with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    new_state_dict = {}

    for key in state_dict.keys():
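        # Only weights that look like cross-attention projections are quantized
        # ("ca*...to_*...weight"); the "ca" prefix is assumed to mark the
        # adapter's cross-attention modules. Everything else stays high precision.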
        if key.startswith("ca") and ".to" in key and "weight" in key:
            print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")

            ## fp8 (e4m3, per-channel)
            weight = state_dict[key]
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
            weight = weight.to(torch.float8_e4m3fn)
            weight_scale = weight_scale.to(torch.float32)
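            # Rough sketch of what per-channel e4m3 quantization computes here
            # (an assumption about FloatQuantizer.real_quant_tensor, not its
            # actual implementation):
            #   amax  = weight.abs().amax(dim=-1, keepdim=True)      # one max per output row
            #   scale = amax / torch.finfo(torch.float8_e4m3fn).max  # e4m3 max is 448.0
            #   q     = weight / scale                               # cast to fp8 afterwards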

            ## Alternative: QuantWeightMxFP4 / QuantWeightMxFP6 / QuantWeightMxFP8 for mxfp4 / mxfp6 / mxfp8
            # weight = state_dict[key].to(torch.bfloat16).cuda()
            # quantizer = QuantWeightMxFP4(weight)
            # weight, weight_scale, _ = quantizer.weight_quant_func(weight)

            new_state_dict[key] = weight.cpu()
            new_state_dict[key + "_scale"] = weight_scale.cpu()
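            # Loaders are expected to dequantize with the paired scale, e.g.
            # (tensor.to(torch.float32) * scale); the "<key>_scale" naming
            # convention is presumably what the matching LightX2V loader reads.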
        else:
            # Weights that don't match the filter fall back to BF16
            print(f"Converting {key} to BF16, dtype: {state_dict[key].dtype}")
            new_state_dict[key] = state_dict[key].to(torch.bfloat16)

    save_file(new_state_dict, str(output_path))
    print(f"Quantized model saved to: {output_path}")


if __name__ == "__main__":
    main()
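
# Example invocation (assuming the script lives under tools/convert/, as its
# imports suggest; both paths below are just the argparse defaults):
#   python tools/convert/quant_adapter.py \
#       --model_path models/SekoTalk-Distill/audio_adapter_model.safetensors \
#       --output_path models/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors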