# neurochem-test.py — evaluate a NeuroChem/TorchANI model on a dataset and
# report the energy RMSE in kcal/mol.
import os
import torch
import torchani
import ignite
import pickle
import argparse

# Parse command line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset. The path can be a hdf5 file or '
                         'a directory containing hdf5 files. It can also be a '
                         'file dumped by pickle.')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
parser.add_argument('--const_file',
                    help='File storing constants',
                    default=torchani.neurochem.buildin_const_file)
parser.add_argument('--sae_file',
                    help='File storing self atomic energies',
                    default=torchani.neurochem.buildin_sae_file)
parser.add_argument('--network_dir',
                    help='Directory or prefix of directories storing networks',
                    default=None)
parser.add_argument('--ensemble',
                    help='Number of models in ensemble',
                    # NOTE(review): no type= is given, so a value passed on the
                    # command line arrives as a str while the default stays
                    # False — confirm torchani.neurochem.load_model accepts both.
                    default=False)
# The parsed Namespace deliberately rebinds the name `parser`; the rest of
# the script reads options as `parser.<option>`.
parser = parser.parse_args()

# Device that all modules and tensors are moved to.
device = torch.device(parser.device)
# Load NeuroChem constants and the self atomic energies table.
consts = torchani.neurochem.Constants(parser.const_file)
sae = torchani.neurochem.load_sae(parser.sae_file)

# AEV featurizer configured from the constants file.
aev_computer = torchani.AEVComputer(**consts)

# Trained network(s): a single model or an ensemble, per CLI options.
nn = torchani.neurochem.load_model(consts.species,
                                   from_=parser.network_dir,
                                   ensemble=parser.ensemble)

# Full pipeline (input -> AEVs -> energies), wrapped for the evaluator and
# moved to the requested device.
model = torch.nn.Sequential(aev_computer, nn)
container = torchani.training.Container({'energies': model}).to(device)

# Load datasets.  Energies are shifted by subtracting per-species self
# atomic energies before batching.
shift_energy = torchani.EnergyShifter(consts.species, sae)

if parser.dataset_path.endswith(('.h5', '.hdf5')) or \
        os.path.isdir(parser.dataset_path):
    # HDF5 file (or a directory of them): batch and transform on load.
    dataset = torchani.training.BatchedANIDataset(
        parser.dataset_path, consts.species, parser.batch_size,
        device=device, transform=[shift_energy.subtract_from_dataset])
    datasets = [dataset]
else:
    # Anything else is assumed to be a pickled dataset, or a list/tuple of
    # datasets.  SECURITY: pickle.load executes arbitrary code from the
    # file — only pass trusted paths.
    with open(parser.dataset_path, 'rb') as f:
        datasets = pickle.load(f)
        if not isinstance(datasets, (list, tuple)):
            datasets = [datasets]


# prepare evaluator
def hartree2kcal(x):
    """Convert an energy from Hartree to kcal/mol."""
    return x * 627.509


for ds in datasets:
    # Build a fresh evaluator per dataset so metric state never carries
    # over between runs.
    rmse_metric = {'RMSE': torchani.training.RMSEMetric('energies')}
    evaluator = ignite.engine.create_supervised_evaluator(
        container, metrics=rmse_metric)
    evaluator.run(ds)
    # RMSE is accumulated in Hartree; report it in kcal/mol.
    rmse_hartree = evaluator.state.metrics['RMSE']
    print(hartree2kcal(rmse_hartree), 'kcal/mol')