pth_to_pickle.py 1.18 KB
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#!/usr/bin/env python3

from argparse import ArgumentParser
import pickle
import numpy as np
import torch

def _build_parser():
    """Return the argument parser describing this script's command line."""
    parser = ArgumentParser(description="Convert a pytorch (.pth) file to a pickled dictionary of numpy arrays. "
                                        "The dictionary will have the following format: \n"
                                        "{pytorch param name: numpy array}")
    parser.add_argument('input_file', type=str, help='input pytorch .pth file')
    parser.add_argument('output_file', type=str, help='output pickle file')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print parameters names and statistics')
    return parser


def _load_as_numpy(path):
    """Load the mapping stored in ``path`` and convert each value to numpy.

    Every value of the loaded mapping must be a tensor; anything else raises
    ``AttributeError`` when the conversion is attempted.

    Returns a new ``{name: numpy array}`` dict in the mapping's own order.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this script on checkpoints you trust.
    with open(path, 'rb') as stream:
        # map_location='cpu' lets checkpoints saved from a GPU be loaded on
        # machines without CUDA.
        state = torch.load(stream, map_location='cpu')
    # .cpu() is a no-op for CPU tensors; it is kept so the conversion never
    # depends on where the tensor ended up after loading.
    return {key: value.detach().cpu().numpy() for key, value in state.items()}


def _format_stats(name, array):
    """Return one ``name, dtype, shape, mean, std, min, max`` report line."""
    if array.size == 0:
        # np.min / np.max raise on zero-size arrays; report NaN instead.
        t_mean = t_std = t_min = t_max = float('nan')
    else:
        # Cast to float: for integer arrays np.min / np.max return numpy
        # integers, which reject the precision in the ``0.3`` format spec.
        t_mean = float(np.mean(array))
        t_std = float(np.std(array))
        t_min = float(np.min(array))
        t_max = float(np.max(array))
    return f"{name}, {array.dtype}, {array.shape}, {t_mean:0.3}, {t_std:0.3}, {t_min:0.3}, {t_max:0.3}"


def main(argv=None):
    """Convert a .pth file into a pickled ``{name: numpy array}`` dict.

    Args:
        argv: Command-line arguments (without the program name). ``None``
            makes argparse read ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    dict_out = _load_as_numpy(args.input_file)

    if args.verbose:
        # The header names every column printed by _format_stats.
        print("name, dtype, shape, mean, std, min, max")
        for key, value in dict_out.items():
            print(_format_stats(key, value))

    # The context manager guarantees the pickle is flushed and closed.
    with open(args.output_file, 'wb') as stream:
        pickle.dump(dict_out, stream)


if __name__ == "__main__":
    main()