import argparse
import os
import sys
import time

from sevenn import __version__

# Sub-command description and per-flag help strings for the 'train' CLI.
description = 'train a model given the input.yaml'

input_yaml_help = 'input.yaml for training'
# Keep this text in sync with the actual default in cmd_parser_train
# (default='train_v2').
mode_help = 'main training script to run. Default is train_v2.'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi'

# Metainfo will be saved to checkpoint
global_config = {
    'version': __version__,
    'when': time.ctime(),  # human-readable timestamp of this run
    '_model_type': 'E3_equivariant_model',
}


def run(args):
    """Entry point of the 'train' sub-command.

    Reads and validates ``input.yaml``, prepares the working directory,
    initializes (optionally distributed) process state, seeds RNGs, and
    dispatches to ``train`` (v1) or ``train_v2``.

    Args:
        args: argparse.Namespace produced by ``cmd_parser_train``.

    Raises:
        ImportError: cuEquivariance requested but not available.
        ValueError: unknown distributed backend.
        NotImplementedError: double precision requested (not supported yet).
    """
    import random

    import torch
    import torch.distributed as dist

    import sevenn._keys as KEY
    from sevenn.logger import Logger
    from sevenn.parse_input import read_config_yaml
    from sevenn.scripts.train import train, train_v2
    from sevenn.util import unique_filepath

    input_yaml = args.input_yaml
    mode = args.mode
    working_dir = args.working_dir
    log = args.log
    screen = args.screen
    distributed = args.distributed
    distributed_backend = args.distributed_backend
    use_cue = args.enable_cueq

    if use_cue:
        # Fail fast before any heavy setup if cuEquivariance was requested
        # but is not importable/usable.
        import sevenn.nn.cue_helper

        if not sevenn.nn.cue_helper.is_cue_available():
            raise ImportError('cuEquivariance not installed.')

    if working_dir is None:
        working_dir = os.getcwd()
    else:
        # exist_ok makes this a no-op when the directory already exists,
        # so no need to check isdir first.
        os.makedirs(working_dir, exist_ok=True)

    world_size = 1
    if distributed:
        # Rank/world-size environment variables depend on the launcher:
        # torchrun sets LOCAL_RANK/RANK/WORLD_SIZE; OpenMPI sets the
        # OMPI_COMM_WORLD_* equivalents.
        if distributed_backend == 'nccl':
            local_rank = int(os.environ['LOCAL_RANK'])
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
        elif distributed_backend == 'mpi':
            local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        else:
            raise ValueError(f'Unknown distributed backend: {distributed_backend}')

        dist.init_process_group(
            backend=distributed_backend, world_size=world_size, rank=rank
        )
    else:
        local_rank, rank, world_size = 0, 0, 1

    # unique_filepath avoids clobbering logs from previous runs in the
    # same working directory.
    log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}')
    with Logger(filename=log_fname, screen=screen, rank=rank) as logger:
        logger.greeting()

        if distributed:
            logger.writeline(
                f'Distributed training enabled, total world size is {world_size}'
            )

        try:
            model_config, train_config, data_config = read_config_yaml(
                input_yaml, return_separately=True
            )
        except Exception as e:
            logger.writeline('Failed to parse input.yaml')
            logger.error(e)
            sys.exit(1)

        # Record the distributed topology so downstream code can read it
        # from the config rather than the environment.
        train_config[KEY.IS_DDP] = distributed
        train_config[KEY.DDP_BACKEND] = distributed_backend
        train_config[KEY.LOCAL_RANK] = local_rank
        train_config[KEY.RANK] = rank
        train_config[KEY.WORLD_SIZE] = world_size

        if distributed:
            # Pin this process to its local GPU before any CUDA work.
            torch.cuda.set_device(torch.device('cuda', local_rank))

        if use_cue:
            # Force-enable cuEquivariance, preserving any other options the
            # user may have put under the cueq config section.
            if KEY.CUEQUIVARIANCE_CONFIG not in model_config:
                model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True}
            else:
                model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})

        logger.print_config(model_config, data_config, train_config)
        # don't have to distinguish configs inside program
        global_config.update(model_config)
        global_config.update(train_config)
        global_config.update(data_config)

        # Not implemented
        if global_config[KEY.DTYPE] == 'double':
            raise NotImplementedError('double precision is not implemented yet')
            # torch.set_default_dtype(torch.double)

        seed = global_config[KEY.RANDOM_SEED]
        random.seed(seed)
        torch.manual_seed(seed)

        # run train
        if mode == 'train_v1':
            train(global_config, working_dir)
        elif mode == 'train_v2':
            train_v2(global_config, working_dir)


def cmd_parser_train(parser):
    """Attach the 'train' sub-command's arguments to *parser*."""
    parser.add_argument('input_yaml', type=str, help=input_yaml_help)
    parser.add_argument(
        '-m',
        '--mode',
        type=str,
        default='train_v2',
        choices=['train_v1', 'train_v2'],
        help=mode_help,
    )
    parser.add_argument(
        '-cueq',
        '--enable_cueq',
        action='store_true',
        help='(Not stable!) use cuEquivariance for training',
    )
    parser.add_argument(
        '-w',
        '--working_dir',
        type=str,
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
    )
    parser.add_argument(
        '-l',
        '--log',
        type=str,
        default='log.sevenn',
        help='name of logfile, default is log.sevenn',
    )
    parser.add_argument('-s', '--screen', action='store_true', help=screen_help)
    parser.add_argument(
        '-d', '--distributed', action='store_true', help=distributed_help
    )
    parser.add_argument(
        '--distributed_backend',
        type=str,
        default='nccl',
        choices=['nccl', 'mpi'],
        help=distributed_backend_help,
    )


def add_parser(subparsers):
    """Register the 'train' sub-command on *subparsers*."""
    train_parser = subparsers.add_parser('train', help=description)
    cmd_parser_train(train_parser)


def set_default_subparser(self, name, args=None, positional_args=0):
    """Make *name* the sub-command used when none is given on the CLI.

    Call after all sub-parsers are set up, just before parse_args().
    ``args``, if given, is the explicit argument list that will be handed
    to parse_args(); otherwise ``sys.argv`` is modified in place.

    Hack copied from stack overflow.
    """
    cli_tokens = sys.argv[1:]

    # A bare -h/--help should show the global help, not the default
    # sub-command's help.
    if any(tok in ('-h', '--help') for tok in cli_tokens):
        return

    # Collect every registered sub-command name.
    known = set()
    for action in self._subparsers._actions:
        if isinstance(action, argparse._SubParsersAction):
            known.update(action._name_parser_map)

    # A sub-command was explicitly given; nothing to do.
    if any(sub in cli_tokens for sub in known):
        return

    # Insert the default just before the trailing global positional
    # arguments; this implies no global options appear after the first
    # positional argument.
    target = sys.argv if args is None else args
    target.insert(len(target) - positional_args, name)


argparse.ArgumentParser.set_default_subparser = set_default_subparser  # type: ignore


def main():
    """CLI entry point: build the top-level parser, register every
    sub-command, and dispatch to the selected command's runner.
    """
    # Sub-command modules are imported lazily so that `sevenn --help`
    # stays fast and unrelated import failures don't break the CLI.
    import sevenn.main.sevenn_cp as checkpoint_cmd
    import sevenn.main.sevenn_get_model as get_model_cmd
    import sevenn.main.sevenn_graph_build as graph_build_cmd
    import sevenn.main.sevenn_inference as inference_cmd
    import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd
    import sevenn.main.sevenn_preset as preset_cmd

    ag = argparse.ArgumentParser(f'SevenNet version={__version__}')

    subparsers = ag.add_subparsers(dest='command', help='Sub-commands')
    add_parser(subparsers)  # add 'train'
    checkpoint_cmd.add_parser(subparsers)
    inference_cmd.add_parser(subparsers)
    graph_build_cmd.add_parser(subparsers)
    preset_cmd.add_parser(subparsers)
    get_model_cmd.add_parser(subparsers)
    patch_lammps_cmd.add_parser(subparsers)

    # Backward compatibility: `sevenn input.yaml` (no sub-command) is
    # treated as `sevenn train input.yaml` via the set_default_subparser
    # hack monkey-patched onto ArgumentParser above.
    ag.set_default_subparser('train')  # type: ignore
    args = ag.parse_args()

    if args.command is None:  # backward compatibility
        args.command = 'train'

    # NOTE(review): only 'train' and 'preset' are dispatched here;
    # presumably the other sub-commands attach their own handlers inside
    # their add_parser (e.g. via set_defaults) — verify, otherwise they
    # parse arguments but never run.
    if args.command == 'train':
        run(args)
    elif args.command == 'preset':
        preset_cmd.run(args)


if __name__ == '__main__':
    main()