"vscode:/vscode.git/clone" did not exist on "03886917bd59f12a1420a99150997732ffea52da"
neurochem-test.py 2.9 KB
Newer Older
1
2
3
4
5
6
7
import os
import torch
import torchani
import ignite
import pickle
import argparse

# Pretrained ANI-1x model; its bundled files supply the defaults for the
# --const_file / --sae_file / --network_dir options below.
ani1x = torchani.models.ANI1x()

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset. The path can be a hdf5 file or \
                    a directory containing hdf5 files. It can also be a file \
                    dumped by pickle.')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
parser.add_argument('--const_file',
                    help='File storing constants',
                    default=ani1x.const_file)
parser.add_argument('--sae_file',
                    help='File storing self atomic energies',
                    default=ani1x.sae_file)
parser.add_argument('--network_dir',
                    help='Directory or prefix of directories storing networks',
                    default=ani1x.ensemble_prefix + '0/networks')
parser.add_argument('--compare_with',
                    help='The TorchANI model to compare with', default=None)
# NOTE: `parser` is rebound to the parsed Namespace here; the rest of the
# script reads options as `parser.<option>` (kept for compatibility).
parser = parser.parse_args()

# load modules and datasets
device = torch.device(parser.device)
# Symmetry-function constants, the self-atomic-energy shifter, and the AEV
# featurizer are all loaded from NeuroChem-format files.
consts = torchani.neurochem.Constants(parser.const_file)
shift_energy = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
# Full pipeline: (species, coordinates) -> AEVs -> atomic networks -> energies.
model = torch.nn.Sequential(aev_computer, nn)
# Container adapts the model to ignite's supervised-evaluator interface,
# keyed by output name.
container = torchani.ignite.Container({'energies': model})
container = container.to(device)

# load datasets: hdf5 file(s) go through the torchani loader; anything else is
# assumed to be an already-batched dataset dumped by pickle.
if parser.dataset_path.endswith(('.h5', '.hdf5')) or \
   os.path.isdir(parser.dataset_path):
    dataset = torchani.data.load_ani_dataset(
        parser.dataset_path, consts.species_to_tensor, parser.batch_size,
        device=device, transform=[shift_energy.subtract_from_dataset])
    datasets = [dataset]
else:
    # SECURITY: pickle.load executes arbitrary code from the file — only use
    # with trusted dataset dumps.
    with open(parser.dataset_path, 'rb') as f:
        datasets = pickle.load(f)
        # Normalize to a list so the evaluation loop can iterate uniformly.
        if not isinstance(datasets, (list, tuple)):
            datasets = [datasets]


# prepare evaluator
def hartree2kcal(x):
    """Convert an energy from Hartree to kcal/mol."""
    return x * 627.509


def evaluate(dataset, container):
    """Run `container` over `dataset` and print the energy RMSE in kcal/mol.

    A fresh ignite evaluator is built per call so metrics never accumulate
    across datasets.
    """
    evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
        'RMSE': torchani.ignite.RMSEMetric('energies')
    })
    evaluator.run(dataset)
    metrics = evaluator.state.metrics
    # The metric is computed in Hartree; report in the conventional unit.
    rmse = hartree2kcal(metrics['RMSE'])
    print(rmse, 'kcal/mol')
# Evaluate the NeuroChem-loaded networks on every dataset.
for dataset in datasets:
    evaluate(dataset, container)


# Optionally reload the network weights from a TorchANI checkpoint and
# re-evaluate, so NeuroChem and TorchANI results can be compared side by side.
if parser.compare_with is not None:
    # map_location lets a checkpoint saved on one device (e.g. CUDA) load
    # onto the device selected via --device (e.g. CPU).
    nn.load_state_dict(torch.load(parser.compare_with, map_location=device))
    print('TorchANI results:')
    for dataset in datasets:
        evaluate(dataset, container)