inspect_deepspeed_checkpoint.py 3.64 KB
Newer Older
hepj987's avatar
hepj987 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
from pathlib import Path
 
# Make the repository root importable when this script is run directly.
# parents[2] is the directory two levels above the one containing this file.
# NOTE(review): the original comment calls this "megatron's root dir" --
# confirm the script's location in the tree still matches that assumption.
root_repo_path = str(Path(__file__).resolve().parents[2])
if root_repo_path not in sys.path:
    sys.path.insert(0, root_repo_path)
    
import argparse

from deepspeed.checkpoint import DeepSpeedCheckpoint 


def list_files(file_list, tag):
    """Print a header for ``tag`` followed by a 1-based numbered listing.

    Args:
        file_list: Iterable of entries to print, one per line.
        tag: Label shown in the ``Listing files: ...`` header line.
    """
    print(f'Listing files: {tag}')
    # Numbering starts at 1 so the output reads like a human-facing list.
    for position, entry in enumerate(file_list, start=1):
        print(f'{position}: {entry}')


def parse_arguments():
    """Parse the command line, echo the parsed namespace, and return it.

    Recognised options (all optional, defaulting to ``None``):
        --folder: DeepSpeed checkpoint folder (str).
        --target_tp: Target TP degree (int).
        --target_pp: Target PP degree (int).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser()

    # (flag, value type, help text); registered in this order so the
    # namespace is populated exactly as before.
    option_specs = (
        ('--folder', str, 'DeepSpeed Checkpoint folder'),
        ('--target_tp', int, 'Target TP degree'),
        ('--target_pp', int, 'Target PP degree'),
    )
    for flag, value_type, description in option_specs:
        parser.add_argument(flag,
                            default=None,
                            type=value_type,
                            help=description)

    parsed = parser.parse_args()
    print(f'args = {parsed}')
    return parsed


def show_input_files(ds_checkpoint):
    """Print the file groups exposed by ``ds_checkpoint``, one listing each.

    Args:
        ds_checkpoint: Object providing ``file_list``, ``zero_files``,
            ``layer_files`` and ``mp_rank_files`` attributes.
    """
    # (attribute on the checkpoint, tag shown in the listing header)
    file_groups = (
        ('file_list', 'all'),
        ('zero_files', 'zero'),
        ('layer_files', 'layer'),
        ('mp_rank_files', 'mp rank'),
    )
    for attribute, tag in file_groups:
        list_files(getattr(ds_checkpoint, attribute), tag)


def show_simple_state(ds_checkpoint):
    """Print layer info, parallelism degrees and the old/new 2D maps.

    Args:
        ds_checkpoint: Object providing ``layer_keys``, ``layer_count``,
            ``original_tp_degree``, ``tp_degree``, ``original_pp_degree``,
            ``pp_degree``, ``dp_degree``, plus ``old_2d_map`` and
            ``new_2d_map`` objects that each have a ``print_data(tag)``
            method.
    """
    layer_keys = ds_checkpoint.layer_keys
    print(f'layer keys = {layer_keys}')
    layer_count = ds_checkpoint.layer_count
    print(f'layer count = {layer_count}')

    # Each degree line shows the original value on the left of the arrow and
    # the current value on the right.
    tp_before, tp_after = ds_checkpoint.original_tp_degree, ds_checkpoint.tp_degree
    print(f'tp_degree_count = {tp_before} ------> {tp_after}')

    pp_before, pp_after = ds_checkpoint.original_pp_degree, ds_checkpoint.pp_degree
    print(f'pp_degree_count = {pp_before} ------> {pp_after}')

    dp_degree = ds_checkpoint.dp_degree
    print(f'dp_degree_count = {dp_degree}')

    for map_attribute, banner in (('old_2d_map', 'old 2d map ==>'),
                                  ('new_2d_map', 'new 2d map ==>')):
        getattr(ds_checkpoint, map_attribute).print_data(banner)


def show_mappings(ds_checkpoint):
    """Invoke each of the checkpoint's ``show_*`` mapping printers in order.

    Args:
        ds_checkpoint: Object providing the zero-argument methods named in
            ``printer_names`` below.
    """
    # Method names are spelled exactly as the checkpoint object exposes them
    # (including "tranformer"); do not "correct" them here.
    printer_names = (
        'show_pp_tranformer_map',
        'show_transformer_file_map',
        'show_tp_embedding_map',
        'show_tp_final_norm_map',
        'show_2d_mapping',
    )
    for method_name in printer_names:
        getattr(ds_checkpoint, method_name)()


def show_state_summary(tag, sd):
    """Print ``tag`` followed by a mapping of each key in ``sd`` to its shape.

    Args:
        tag: Label printed before the summary.
        sd: Mapping whose values each expose a ``.shape`` attribute.
    """
    shape_by_key = {key: value.shape for key, value in sd.items()}
    print(f'{tag} = {shape_by_key}')


def show_embedding_states(ds_checkpoint):
    """Print a shape summary of the embedding state for every TP index.

    Args:
        ds_checkpoint: Object providing ``tp_degree`` and
            ``get_embedding_state(index)``.
    """
    for tp_index in range(ds_checkpoint.tp_degree):
        embedding_sd = ds_checkpoint.get_embedding_state(tp_index)
        show_state_summary(f'embedding[{tp_index}]', embedding_sd)


def show_final_norm_states(ds_checkpoint):
    """Print a shape summary of the final-norm state for every TP index.

    Args:
        ds_checkpoint: Object providing ``tp_degree`` and
            ``get_final_norm_state(index)``.
    """
    for tp_index in range(ds_checkpoint.tp_degree):
        final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
        show_state_summary(f'final_norm[{tp_index}]', final_norm_sd)


def show_transformer_states(ds_checkpoint):
    """Print per-block shape summaries for every (TP, PP) index pair.

    Args:
        ds_checkpoint: Object providing ``tp_degree``, ``pp_degree`` and
            ``get_transformer_state(tp_index=..., pp_index=...)``, which
            returns an iterable of state dicts (one per block).
    """
    for tp_index in range(ds_checkpoint.tp_degree):
        for pp_index in range(ds_checkpoint.pp_degree):
            block_states = ds_checkpoint.get_transformer_state(
                tp_index=tp_index, pp_index=pp_index)
            print(f'tp_pp_rank[{tp_index},{pp_index}] = ')
            for block_index, block_sd in enumerate(block_states):
                show_state_summary(f'      block[{block_index}]', block_sd)
                # Blank line after each block keeps the summaries separated.
                print("")


def main():
    """Load the checkpoint named on the command line and print its contents.

    Builds a ``DeepSpeedCheckpoint`` from ``--folder``, ``--target_tp`` and
    ``--target_pp``, calls its ``validate_files()``, then runs each report
    function in turn and finally prints the result of ``get_args()``.
    """
    print('Inspecting DeepSpeed Checkpoint')
    cli_args = parse_arguments()

    ds_checkpoint = DeepSpeedCheckpoint(cli_args.folder,
                                        cli_args.target_tp,
                                        cli_args.target_pp)
    ds_checkpoint.validate_files()

    report_steps = (
        show_simple_state,
        show_input_files,
        # NOTE(review): the simple state is printed a second time here, as in
        # the original sequence; confirm whether the repeat is intentional.
        show_simple_state,
        show_mappings,
        show_embedding_states,
        show_final_norm_states,
        show_transformer_states,
    )
    for report in report_steps:
        report(ds_checkpoint)

    print(f'checkpoint args = {ds_checkpoint.get_args()}')

# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()