# bf16_cast_channel_int4_v2.py
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file
from huggingface_hub import snapshot_download

import matplotlib.pyplot as plt

def get_plot(matrix: torch.Tensor):
    """Debugging helper: save a histogram of each row's values as a PNG."""
    n_rows = matrix.shape[0]
    row_labels = [f"Row_{i}" for i in range(n_rows)]
    os.makedirs("./result", exist_ok=True)

    # matplotlib cannot consume bf16 or CUDA tensors, so convert up front.
    data = matrix.detach().float().cpu().numpy()

    # Generate and save a separate image for each row.
    for i in range(n_rows):
        plt.figure(figsize=(8, 4))
        plt.hist(data[i, :], bins=20, alpha=0.7, color='green')
        plt.title(f"Distribution of {row_labels[i]}")
        plt.xlabel("Value")
        plt.ylabel("Frequency")
        plt.savefig(f"./result/row_{i}_histogram.png")  # save as PNG
        plt.close()



def weight_quant(tensor: torch.Tensor):
    """Symmetric per-channel (per-row) int8 quantization."""
    assert tensor.dim() == 2
    qmax = 127.0  # quantized values span -127 to 127
    abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0]  # [rows, 1]
    scale = abs_max / qmax  # [rows, 1]
    assert scale.shape == (tensor.shape[0], 1)
    quantized = torch.round(tensor / scale)
    quantized = torch.clamp(quantized, -qmax, qmax)
    return quantized.to(torch.int8), scale.to(torch.float32)
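
# A minimal sketch (hypothetical helper, not used by the pipeline below) of the
# inverse mapping, handy for eyeballing quantization error:
#   w ≈ q * scale, with the per-row scale broadcasting over columns.
def weight_dequant(quantized: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return quantized.to(torch.float32) * scale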

def weight_quantint4(tensor: torch.Tensor):
    """Symmetric per-channel int4 quantization with 95th-percentile clipping,
    packing two signed 4-bit values into each output byte."""
    assert tensor.dim() == 2
    qmax = 7.0  # quantized values span -7 to 7

    # Clip at the 95th percentile of each row's absolute values rather than
    # the maximum, so a few outliers do not inflate the scale.
    abs_value = torch.abs(tensor)
    sorted_matrix, _ = torch.sort(abs_value, dim=1)
    k = tensor.shape[1]
    index = int(k * 0.95)
    abs_max = sorted_matrix[:, index].reshape(-1, 1)  # [rows, 1]

    scale = abs_max / qmax  # [rows, 1]
    assert scale.shape == (tensor.shape[0], 1)

    # Quantize; values beyond the clipping point saturate at ±7.
    quantized = torch.round(tensor / scale)
    quantized = torch.clamp(quantized, -qmax, qmax).to(torch.int8)

    # Reinterpret as uint8 so negative values become their two's-complement
    # bit patterns (e.g. -7 -> 249), which pack cleanly into nibbles.
    quantized_int8 = quantized.to(torch.uint8)

    # Pack adjacent column pairs into one byte: even columns go to the high
    # nibble, odd columns to the low nibble. The left shift discards the high
    # bits of `a`, so only `b` needs an explicit 0x0F mask.
    a = quantized_int8[..., ::2]
    b = quantized_int8[..., 1::2]
    quantized_int4 = ((a << 4) | (b & 0x0F)).contiguous().to(torch.int8)  # [rows, k // 2]

    return quantized_int4, scale.to(torch.float32)
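
# A minimal sketch of the nibble-unpacking convention this packing targets,
# reconstructed from debug code that was left commented out here; treat it as
# an assumption about the consuming kernel, not a verified API. Each nibble is
# extracted *in place* as a signed int8, so every value comes out multiplied
# by 16 -- which is why main() divides the combined scale by 16.
def unpack_int4_times16(packed: torch.Tensor) -> torch.Tensor:
    # Duplicate each packed byte so even/odd output columns can be filled
    # from the high/low nibble of the same byte.
    expanded = torch.repeat_interleave(packed, repeats=2, dim=-1).to(torch.uint8)
    out = torch.empty_like(expanded)
    out[..., ::2] = expanded[..., ::2] & 0xF0            # high nibble in place: value * 16
    out[..., 1::2] = (expanded[..., 1::2] << 4) & 0xF0   # low nibble shifted up: value * 16
    # Reinterpret the uint8 bit patterns as signed int8 (wraps, as intended).
    return out.to(torch.int8)

# Dequantization is then approximately
#   w ≈ unpack_int4_times16(q4).to(torch.float32) * combined_scale
# with the /16 already folded into combined_scale by main().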



def main(bf16_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"):
    """Convert a bf16 HF checkpoint to per-channel int8, with packed int4 for
    MoE expert weights, and write the new safetensors files plus index."""
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(int8_path, exist_ok=True)
    model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
    config_file = os.path.join(int8_path, "config.json")

    if not os.path.exists(model_index_file) or not os.path.exists(config_file):
        snapshot_download(
            repo_id=model_name,
            ignore_patterns=["*.safetensors"],
            local_dir=int8_path,
            local_dir_use_symlinks=False
        )
        print(f"model index file and config file downloaded to {int8_path}")

        # modify config.json and save it
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        # delete quantization_config
        config.pop("quantization_config", None)
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True)
        print(f"config.json modified and saved to {config_file}")

    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
    scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")])
    
    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
    safetensor_files.sort()
    quant_count = 0
    new_weight_map = {}
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        state_dict = load_file(safetensor_file, device="cuda")
        new_state_dict = {}
        for weight_name, weight in state_dict.items():
            scale_inv_name = f"{weight_name}_scale_inv"
            if scale_inv_name in weight_map:
                assert weight.element_size() == 2  # expect a bf16 weight
                quant_count += 1
                int8_weight, scale_inv = weight_quant(weight)
                new_scale_name = scale_inv_name.replace("_scale_inv", "_scale")

                if ".mlp.experts." in weight_name:
                    # Double quantization for expert weights: the int8 weight is
                    # re-quantized to packed int4 and the two per-row scales are
                    # folded into one. The /16 matches a consumer that reads each
                    # nibble in place, i.e. pre-multiplied by 16 (see the unpack
                    # sketch after weight_quantint4).
                    int4_weight, scale_int4 = weight_quantint4(int8_weight)
                    new_state_dict[weight_name] = int4_weight
                    new_state_dict[new_scale_name] = scale_inv * scale_int4 / 16
                else:
                    new_state_dict[weight_name] = int8_weight
                    new_state_dict[new_scale_name] = scale_inv

                new_weight_map[weight_name] = file_name
                new_weight_map[new_scale_name] = file_name
            else:
                # Tensors without a matching *_scale_inv entry (norms,
                # embeddings, ...) are copied through unchanged.
                new_state_dict[weight_name] = weight
                new_weight_map[weight_name] = file_name
        new_safetensor_file = os.path.join(int8_path, file_name)
        save_file(new_state_dict, new_safetensor_file)
    # Sanity check: every *_scale_inv entry should yield one quantized weight.
    # assert quant_count == scale_count
    print(f"{quant_count} weights are quantized.")

    # modify model.safetensors.index.json
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    model_index["weight_map"] = new_weight_map
    with open(model_index_file, "w", encoding="utf-8") as f:
        json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True)
    print(f"model.safetensors.index.json modified and saved to {model_index_file}")
        

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-bf16-hf-path", type=str, default="/dataset/llm-models/deepseek-r1/DeepSeek-R1-0528-bf16")
    parser.add_argument("--output-int8-hf-path", type=str, default="/FrameWork/0307/3/modeltrans/DeepSeek-R1-0528-SlimQuant-W4A8")
    parser.add_argument("--model-name", type=str, default="deepseek-ai/DeepSeek-R1")

    args = parser.parse_args()
    main(args.input_bf16_hf_path, args.output_int8_hf_path, args.model_name)
    print("done")