# inspect_checkpoint.py (1.17 KB)
# Last committed by hepj987.
import sys
import torch
import os
from collections import OrderedDict
from pathlib import Path

# Insert megatron's root dir into sys.path so its packages can be imported
# when this script is run directly.
# `parents[2]` is the directory two levels above the one containing this file
# (NOTE(review): presumably the Megatron repository root -- confirm against
# where this script lives in the tree).
root_repo_path = str(Path(__file__).resolve().parents[2])
# Prepend only once so repeated imports of this module do not grow sys.path.
if root_repo_path not in sys.path:
    sys.path.insert(0, root_repo_path)


def dump_data(datum, name_list=None):
    """Recursively print a one-line summary of every leaf inside ``datum``.

    Dictionaries (including subclasses such as ``OrderedDict`` or
    ``defaultdict``) are descended into, appending ``str(key)`` to the dotted
    name. Plain lists and tuples are descended into without extending the
    name, so all of their elements are reported under the same prefix.
    Tensors are reported as ``[tensor] <name> = <shape>``; every other value
    is reported as ``[other] <name> = <value>``.

    Args:
        datum: The object to summarise (typically a loaded checkpoint).
        name_list: List of key names leading to ``datum``. ``None`` (the
            default) means ``datum`` is the top-level object.
    """
    # Avoid a shared mutable default argument.
    if name_list is None:
        name_list = []

    if isinstance(datum, dict):
        for key, value in datum.items():
            dump_data(value, name_list + [str(key)])
    # Exact type check on purpose: tuple subclasses such as torch.Size or
    # namedtuples are printed whole by the final branch instead of being
    # expanded element by element.
    elif type(datum) in (list, tuple):
        for item in datum:
            dump_data(item, name_list)
    elif torch.is_tensor(datum):
        prefix = '.'.join(name_list)
        print(f'[tensor] {prefix} = {datum.shape}')
    else:
        prefix = '.'.join(name_list)
        print(f'[other] {prefix} = {datum}')


def main():
    """Command-line entry point: load a checkpoint file and dump its contents.

    Reads the checkpoint path from ``sys.argv[1]``, loads it onto the CPU
    with ``torch.load`` and passes the loaded object to ``dump_data``.

    Raises:
        SystemExit: With status 1 when no path is given or the path is not an
            existing regular file; with the default (success) status once the
            checkpoint has been dumped.
    """
    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} <checkpoint file>')
        # sys.exit instead of the site-provided exit()/quit() helpers, which
        # are meant for the interactive shell and may be unavailable
        # (e.g. under `python -S`).
        sys.exit(1)

    ckpt_file = sys.argv[1]
    if not os.path.isfile(ckpt_file):
        print(f'{ckpt_file} is not a valid file')
        sys.exit(1)

    print(f'loading checkpoint file: {ckpt_file}')
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only inspect checkpoints that come from a trusted source.
    sd = torch.load(ckpt_file, map_location=torch.device('cpu'))
    dump_data(sd)

    # Keep terminating via SystemExit with the default (success) status, as
    # the previous trailing quit() did.
    sys.exit()


# Run the inspector only when executed as a script, not when imported.
if __name__ == "__main__":
    main()