# quant_adapter.py
import argparse
import sys
from pathlib import Path

import safetensors
import torch
from safetensors.torch import save_file

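# Make the repository root importable so lightx2v resolves without installation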
sys.path.append(str(Path(__file__).parent.parent.parent))

from lightx2v.utils.quant_utils import FloatQuantizer


def main():
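    """Quantize the audio adapter: cross-attention weights to FP8 (e4m3), all other tensors to BF16."""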
    # Get the directory containing this script
    script_dir = Path(__file__).parent
    project_root = script_dir.parent.parent

    parser = argparse.ArgumentParser(description="Quantize audio adapter model to FP8")
    parser.add_argument(
        "--model_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill" / "audio_adapter_model.safetensors"),
        help="Path to input model file",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=str(project_root / "models" / "SekoTalk-Distill-fp8" / "audio_adapter_model_fp8.safetensors"),
        help="Path to output quantized model file",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    output_path = Path(args.output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)

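    # Load every tensor from the input checkpoint onto CPU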
    state_dict = {}
    with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

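    # Rebuild the state dict: FP8 weight plus per-channel scale for quantized keys, BF16 otherwise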
    new_state_dict = {}

    for key in state_dict.keys():
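        # Quantize only weights that look like cross-attention projections
        # (keys such as "ca*.to_q.weight"; an assumption based on the naming pattern)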
        if key.startswith("ca") and ".to" in key and "weight" in key:
            print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")

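            # Quantize in FP32 on the GPU (this step requires CUDA).
            # real_quant_tensor presumably returns the quantized tensor, a
            # per-output-channel scale, and an unused zero point.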
            weight = state_dict[key].to(torch.float32).cuda()
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
            weight = weight.to(torch.float8_e4m3fn)
            weight_scale = weight_scale.to(torch.float32)

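            # Store the FP8 weight alongside its scale under "<key>_scale"
            # so a loader can dequantize at inference time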
            new_state_dict[key] = weight.cpu()
            new_state_dict[key + "_scale"] = weight_scale.cpu()
        else:
            # Non-matching weights are converted to BF16
            print(f"Converting {key} to BF16, dtype: {state_dict[key].dtype}")
            new_state_dict[key] = state_dict[key].to(torch.bfloat16)

    save_file(new_state_dict, str(output_path))
    print(f"Quantized model saved to: {output_path}")


if __name__ == "__main__":
    main()
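
# Example invocation (paths mirror the argparse defaults; the script path is illustrative):
#   python quant_adapter.py \
#       --model_path models/SekoTalk-Distill/audio_adapter_model.safetensors \
#       --output_path models/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors
#
# Dequantization sketch for consumers of the output file (assumes the stored
# per-channel scale broadcasts against the FP8 weight):
#   w_bf16 = (w_fp8.to(torch.float32) * w_scale).to(torch.bfloat16)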