# convert_model_into_blocks.py
# Split a PyTorch checkpoint (state dict) into numbered .bin shards plus a JSON index file.
import argparse
import collections
import torch
import os
import json


# --block_size is interpreted in decimal gigabytes (10**9 bytes).
BYTES_PER_GB = 10**9
INDEX_FILENAME = "tencentpretrain_model.bin.index.json"


def _shard_filename(shard_id):
    """Return the file name used for shard number *shard_id* (0-based)."""
    return f"tencentpretrain_model-{shard_id}.bin"


def _tensor_nbytes(tensor):
    """Return the number of bytes held by *tensor*'s elements."""
    # Count real bytes rather than elements so fp32/bf16/int8 checkpoints are
    # sized correctly instead of assuming 2 bytes (fp16) per element.
    return tensor.numel() * tensor.element_size()


def shard_state_dict(state_dict, max_shard_bytes):
    """Split *state_dict* into consecutive shards of at most *max_shard_bytes*.

    Entries keep their original order. A new shard is started whenever adding
    the next tensor would push the current shard past the limit, so every
    shard stays within the budget except a shard holding a single tensor that
    is itself larger than the limit. Tensors are referenced, not copied.

    Args:
        state_dict: Mapping of parameter name to tensor.
        max_shard_bytes: Byte budget per shard (tensor data only; the
            serialization overhead of ``torch.save`` is not counted).

    Returns:
        ``(shards, index_dict)`` where ``shards`` is a list of
        ``(filename, OrderedDict)`` pairs and ``index_dict`` holds a
        ``"weight_map"`` (parameter name -> shard filename) followed by
        ``"metadata"`` with the ``"total_size"`` of all tensors in bytes.
    """
    shards = []
    weight_map = {}
    current = collections.OrderedDict()
    current_bytes = 0
    total_bytes = 0

    for name, tensor in state_dict.items():
        nbytes = _tensor_nbytes(tensor)
        # Close the current shard *before* it would overflow; an empty shard
        # always accepts the tensor so oversized tensors get a shard of their own.
        if current and current_bytes + nbytes > max_shard_bytes:
            shards.append((_shard_filename(len(shards)), current))
            current = collections.OrderedDict()
            current_bytes = 0
        current[name] = tensor
        current_bytes += nbytes
        total_bytes += nbytes
        weight_map[name] = _shard_filename(len(shards))

    if current:
        shards.append((_shard_filename(len(shards)), current))

    index_dict = {"weight_map": weight_map, "metadata": {"total_size": total_bytes}}
    return shards, index_dict


def main():
    """Parse CLI arguments, shard the input checkpoint, and write shards + index."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_model_path", type=str, required=True,
                        help="Input file path")
    parser.add_argument("--output_model_path", type=str, required=True,
                        help="Output folder path")
    parser.add_argument("--block_size", type=int, default=10,
                        help="Maximum disk size (GB) of each block.")
    args = parser.parse_args()
    if args.block_size <= 0:
        parser.error("--block_size must be a positive integer")

    # Create the directory without going through a shell: safe for paths with
    # spaces/metacharacters, creates missing parents, tolerates an existing dir.
    os.makedirs(args.output_model_path, exist_ok=True)

    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
    # and keeps the written shards device-neutral.
    # NOTE(security): torch.load unpickles the file; only convert trusted checkpoints.
    input_model = torch.load(args.input_model_path, map_location="cpu")

    shards, index_dict = shard_state_dict(input_model, args.block_size * BYTES_PER_GB)
    for filename, shard in shards:
        torch.save(shard, os.path.join(args.output_model_path, filename))

    index_path = os.path.join(args.output_model_path, INDEX_FILENAME)
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index_dict, f)


if __name__ == "__main__":
    main()