sevenn_inference.py 3.12 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import glob
import os
import sys

# Help text used by both the `inference` sub-command parser (add_parser)
# and the standalone parser built in main().
description = 'evaluate sevenn_data/ase readable with a model (checkpoint).'
checkpoint_help = 'Checkpoint or pre-trained model name'  # positional #1
target_help = 'Target files to evaluate'  # positional #2 (one or more)


def add_parser(subparsers):
    """Register the ``inference`` sub-command (alias ``inf``).

    Args:
        subparsers: the object returned by
            ``ArgumentParser.add_subparsers()``; a new sub-parser is created
            on it and populated via ``add_args``.
    """
    inference_parser = subparsers.add_parser(
        'inference',
        help=description,
        aliases=['inf'],
    )
    add_args(inference_parser)


def add_args(parser):
    """Add the inference command-line arguments to *parser* in place.

    Args:
        parser: an ``argparse.ArgumentParser`` (or a sub-parser) to extend.
    """
    parser.add_argument('checkpoint', type=str, help=checkpoint_help)
    parser.add_argument('targets', type=str, nargs='+', help=target_help)
    parser.add_argument(
        '-d',
        '--device',
        type=str,
        default='auto',
        help='cpu/cuda/cuda:x, or auto (cuda if available), defaults to auto',
    )
    parser.add_argument(
        '-nw',
        '--nworkers',
        type=int,
        default=1,
        help='Number of cores to build graph, defaults to 1',
    )
    parser.add_argument(
        '-o',
        '--output',
        type=str,
        default='./inference_results',
        help='A directory name to write outputs',
    )
    parser.add_argument(
        '-b',
        '--batch',
        type=int,
        default=4,  # an int, matching type=int
        help='batch size, useful for GPU, defaults to 4',
    )
    # The inference run rejects save_graph together with allow_unlabeled;
    # declaring them mutually exclusive lets argparse report it at parse time.
    exclusive = parser.add_mutually_exclusive_group()
    exclusive.add_argument(
        '-s',
        '--save_graph',
        action='store_true',
        help='Additionally, save preprocessed graph as sevenn_data',
    )
    exclusive.add_argument(
        '-au',
        '--allow_unlabeled',
        action='store_true',
        help='Allow energy or force unlabeled data',
    )
    parser.add_argument(
        '-m',
        '--modal',
        type=str,
        default=None,
        help='modality for multi-modal inference',
    )
    # argparse.REMAINDER swallows every token after --kwargs, so this option
    # must come last on the command line.
    parser.add_argument(
        '--kwargs',
        nargs=argparse.REMAINDER,
        help=(
            'key=value pairs; will be passed to reader, '
            'or can be used to specify EFS key'
        ),
    )


def run(args):
    """Run model inference on the target files described by *args*.

    Args:
        args: namespace produced by a parser extended with ``add_args``
            (reads checkpoint, targets, device, nworkers, output, batch,
            save_graph, allow_unlabeled, modal, kwargs).

    Raises:
        ValueError: if ``save_graph`` and ``allow_unlabeled`` are both set,
            or if a ``--kwargs`` entry is not of the form ``key=value``.
        FileExistsError: if the output directory already exists.

    Exits with status 0 (after printing a message) when no target file
    matches the given patterns.
    """
    # Validate purely argument-level conditions first, so a bad invocation
    # fails before torch and the sevenn modules are imported.
    if args.save_graph and args.allow_unlabeled:
        raise ValueError('save_graph and allow_unlabeled are mutually exclusive')

    fmt_kwargs = {}
    for kwarg in args.kwargs or []:
        # Split on the first '=' only, so values may themselves contain '='.
        key, sep, value = kwarg.partition('=')
        if not sep or not key:
            raise ValueError(
                f'Invalid --kwargs entry {kwarg!r}; expected key=value'
            )
        fmt_kwargs[key] = value

    out = args.output
    if os.path.exists(out):
        raise FileExistsError(f'Directory {out} already exists')

    import torch

    from sevenn.scripts.inference import inference
    from sevenn.util import pretrained_name_to_path

    device = args.device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Expand glob patterns. dict.fromkeys drops files matched by more than
    # one pattern while keeping first-seen order.
    targets = list(
        dict.fromkeys(
            path for pattern in args.targets for path in glob.glob(pattern)
        )
    )
    if not targets:
        print('No targets (data to inference) are found')
        sys.exit(0)

    cp = args.checkpoint
    if not os.path.isfile(cp):
        # Not a file on disk: treat it as a pretrained model name
        # (expected to raise ValueError for an unrecognized name).
        cp = pretrained_name_to_path(cp)

    inference(
        cp,
        targets,
        out,
        args.nworkers,
        device,
        args.batch,
        args.save_graph,
        args.allow_unlabeled,
        args.modal,
        **fmt_kwargs,
    )


def main(args=None):
    """Standalone entry point: build the parser, parse *args*, and run.

    Args:
        args: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description=description)
    add_args(parser)
    # Forward *args*; previously it was silently ignored.
    run(parser.parse_args(args))