# neurochem-test.py -- evaluate pretrained NeuroChem/torchani networks
# against a dataset and report the energy RMSE in kcal/mol.
import os
import torch
import torchani
import ignite
import pickle
import argparse

# Built-in pretrained model files shipped with torchani (constants file,
# self-atomic-energy file, ensemble prefix); used below as argparse defaults.
builtins = torchani.neurochem.Builtins()
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset. The path can be a hdf5 file or \
                    a directory containing hdf5 files. It can also be a file \
                    dumped by pickle.')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
parser.add_argument('--const_file',
                    help='File storing constants',
                    default=builtins.const_file)
parser.add_argument('--sae_file',
                    help='File storing self atomic energies',
                    default=builtins.sae_file)
parser.add_argument('--network_dir',
                    help='Directory or prefix of directories storing networks',
                    default=builtins.ensemble_prefix + '0/networks')
# NOTE(review): `parser` is deliberately rebound to the parsed Namespace;
# the rest of the script reads options as `parser.<option>`, so the name
# is kept for compatibility with those call sites.
parser = parser.parse_args()

# load modules and datasets
device = torch.device(parser.device)
# Symmetry-function constants and per-element self atomic energies
consts = torchani.neurochem.Constants(parser.const_file)
shift_energy = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
# Full pipeline: species/coordinates -> AEVs -> per-atom networks -> energies
model = torch.nn.Sequential(aev_computer, nn)
# Container adapts the model to ignite's (input, target) evaluation protocol
container = torchani.ignite.Container({'energies': model})
container = container.to(device)

# load datasets: an hdf5 file / directory of hdf5 files is batched directly;
# anything else is assumed to be a pickle dump of one or more datasets.
if parser.dataset_path.endswith(('.h5', '.hdf5')) or \
        os.path.isdir(parser.dataset_path):
    dataset = torchani.data.BatchedANIDataset(
        parser.dataset_path, consts.species_to_tensor, parser.batch_size,
        device=device, transform=[shift_energy.subtract_from_dataset])
    datasets = [dataset]
else:
    # WARNING: pickle.load can execute arbitrary code -- only open files
    # from trusted sources here.
    with open(parser.dataset_path, 'rb') as f:
        datasets = pickle.load(f)
        # Normalize a single pickled dataset to a one-element list so the
        # evaluation loop below can treat both cases uniformly.
        if not isinstance(datasets, (list, tuple)):
            datasets = [datasets]


# prepare evaluator
def hartree2kcal(x):
    """Convert an energy value from Hartree to kcal/mol."""
    hartree_to_kcalmol = 627.509
    return x * hartree_to_kcalmol


# Evaluate each dataset: compute the energy RMSE (in Hartree) with ignite,
# convert to kcal/mol, and print one line per dataset.
for dataset in datasets:
    evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
        'RMSE': torchani.ignite.RMSEMetric('energies')
    })
    evaluator.run(dataset)
    metrics = evaluator.state.metrics
    rmse = hartree2kcal(metrics['RMSE'])
    print(rmse, 'kcal/mol')