# This script merges an FP8 safetensors checkpoint with GGUF-quantized tensors:
# expert weights are taken from the GGUF files, everything else from the safetensors.

import os
import sys
# Insert the project root here if ktransformers is not installed as a package:
# sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict

def read_safetensor_keys_from_folder(folder_path)->dict:
    """
    Scan a folder for .safetensors files and map each tensor key to its file.

    :param folder_path: folder path (a file path is resolved to its parent folder)
    :return: key_to_file_map mapping tensor key -> safetensors file path
    """
    # Check that the folder path exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Safetensor dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)
    
    key_to_file_map = {}

    found_safetensor = False
    for root, dirs, files in os.walk(folder_path):
        # Sort for a deterministic traversal order
        files = sorted(files)
        for file in files:
            if file.endswith(".safetensors"):
                found_safetensor = True
                file_path = os.path.join(root, file)
                try:
                    with safe_open(file_path, framework="pt") as f:
                        for key in f.keys():
                            if "model.layers.61" in key:
                                # skip the MTP (multi-token prediction) layer
                                continue
                            # Debug helper: uncomment to keep only the first few layers
                            # try:
                            #     if int(key.split('.')[2]) > 4:
                            #         continue
                            # except (IndexError, ValueError):
                            #     pass
                            key_to_file_map[key] = file_path
                except Exception as e:
                    print(f"Error reading Safetensor file {file_path}: {e}")
    
    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
    
    return key_to_file_map
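
# Illustrative usage (keys and shard names depend on the checkpoint layout;
# the path below is a placeholder):
#   key_map = read_safetensor_keys_from_folder("/path/to/DeepSeek-V3")
#   # e.g. {"model.embed_tokens.weight": ".../model-00001-of-00163.safetensors", ...}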

tensor_from_gguf = []  # TODO: list keys that should always be sourced from the GGUF side

def translate_name(name:str)->str:
    """
    Translate a tensor name from safetensors (Hugging Face) convention to GGUF convention.

    :param name: name of the tensor
    :return: translated name
    """
    name = translate_name_to_gguf(name)
    name = name.replace(".up_proj.", ".ffn_up_exps.")
    name = name.replace(".down_proj.", ".ffn_down_exps.")
    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias") 
    return name
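
# Hedged example: assuming translate_name_to_gguf follows the usual llama.cpp
# convention ("model.layers.N." -> "blk.N."), a collapsed expert key such as
#   "model.layers.3.mlp.experts.up_proj.weight"
# would translate to something like "blk.3.ffn_up_exps.weight".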
    

def combine_tensor_sources(safetensor_path:str, gguf_path:str):
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
    
    # Build a map from each tensor key to the file that will supply it.
    target_tensor_map = {}
    for key in safetensor_tensor_file_map.keys():
        # All expert weights are taken from the GGUF side
        if ".mlp.experts." in key:
            # FP8 scale tensors are not needed once the GGUF quantized weights are used
            if '.weight_scale_inv' in key:
                continue
            # Collapse the per-expert index, since GGUF fuses all experts into one tensor:
            # "model.layers.N.mlp.experts.E.up_proj.weight" -> "model.layers.N.mlp.experts.up_proj.weight"
            key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])
            translated_key = translate_name(key)
            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
            continue
        
        if any(target_key in key for target_key in tensor_from_gguf):
            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
        else:
            target_tensor_map[key] = safetensor_tensor_file_map[key]
    
    return target_tensor_map, gguf_loader
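
# Hedged sketch of the resulting map (paths are placeholders):
#   target_map, loader = combine_tensor_sources("/path/to/fp8-safetensors", "/path/to/q4km-gguf")
#   # "model.embed_tokens.weight"                  -> ".../model-00001-of-00163.safetensors"
#   # "model.layers.3.mlp.experts.up_proj.weight"  -> ".../DeepSeek-V3-q4km.gguf"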

def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Cache open safetensors file handles so each shard file is opened only once
    safetensors_cache = {}
    
    # Group tensors by layer
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r'\.layers\.(\d+)\.')
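    # e.g. "model.layers.12.self_attn.q_proj.weight" -> layer 12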
    
    for key in target_tensor_map:
        match = layer_pattern.search(key)
        if match:
            layer_num = int(match.group(1))
            layer_groups[layer_num].append(key)
        else:
            non_layer_keys.append(key)
    
    # Shards are zero-indexed, so total_shards is the highest shard index:
    # one shard per layer group, plus one shard for non-layer tensors
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards < 0:
        raise ValueError("No tensors to save")
    
    shard_idx = 0
    
    # Save non-layer tensors to the first shard if they exist
    if non_layer_keys:
        tensors = {}
        for key in non_layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                # Reuse a cached open handle so each shard file is opened once
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                # Keep the tensor in its raw quantized GGUF form; record its GGML type
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            name = translate_name(key)
            tensors[name] = tensor
            # GGML type 0 (F32) is falsy, so test against None rather than truthiness
            if ggml_type is not None:
                ggml_key = name[:-len(".weight")] + ".ggml_type" if name.endswith(".weight") else name + ".ggml_type"
                tensors[ggml_key] = torch.tensor(ggml_type)
        
        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        # print(tensors.keys())

        shard_idx += 1
    
    # Save each layer's tensors to subsequent shards
    for layer_num in sorted(layer_groups.keys()):
        layer_keys = layer_groups[layer_num]
        tensors = {}
        for key in layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            name = translate_name(key)
            tensors[name] = tensor
            if ggml_type is not None:
                ggml_key = name[:-len(".weight")] + ".ggml_type" if name.endswith(".weight") else name + ".ggml_type"
                tensors[ggml_key] = torch.tensor(ggml_type)
        
        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving layer {layer_num} to {output_file}")
        # print(tensors.keys())
        save_file(tensors, output_file)
        shard_idx += 1
    
    return
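
# Resulting output layout (illustrative for a model with layers 0..60 plus
# shared tensors; shard indices are zero-based, see total_shards above):
#   model-00000-of-00061.safetensors   # embeddings, lm_head, norms, ...
#   model-00001-of-00061.safetensors   # layer 0
#   ...
#   model-00061-of-00061.safetensors   # layer 60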
    
def main():
    # Create the command-line argument parser
    parser = argparse.ArgumentParser(description="Merge an FP8 safetensors checkpoint with GGUF quantized tensors")
    parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor folder", default="/mnt/data/model/DeepSeek-V3")
    parser.add_argument("--gguf_path", type=str, help="Path to the GGUF folder", default="/mnt/data/model/DeepseekV3-q4km-gguf")
    parser.add_argument("--output_path", type=str, help="Path to the output folder", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
    
    # Parse the command-line arguments once, then echo them
    args = parser.parse_args()
    print("All the arguments:")
    print(args)

    safetensor_path = args.safetensor_path
    gguf_path = args.gguf_path
    output_path = args.output_path
    
    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
    write_combined_tensor(target_tensor_map, output_path, gguf_loader)
    
    return

if __name__ == "__main__":
    main()
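
# Example invocation (the paths below are just the script's defaults; point them
# at your own checkpoint folders):
#   python merge_safetensor_gguf.py \
#       --safetensor_path /mnt/data/model/DeepSeek-V3 \
#       --gguf_path /mnt/data/model/DeepseekV3-q4km-gguf \
#       --output_path /mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8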