Unverified Commit 8203102c authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Allow neurochem-test.py to compare with TorchANI checkpoint file (#97)

parent 00e135ad
...@@ -29,6 +29,8 @@ parser.add_argument('--sae_file', ...@@ -29,6 +29,8 @@ parser.add_argument('--sae_file',
parser.add_argument('--network_dir', parser.add_argument('--network_dir',
help='Directory or prefix of directories storing networks', help='Directory or prefix of directories storing networks',
default=builtins.ensemble_prefix + '0/networks') default=builtins.ensemble_prefix + '0/networks')
parser.add_argument('--compare_with',
help='The TorchANI model to compare with', default=None)
parser = parser.parse_args() parser = parser.parse_args()
# load modules and datasets # load modules and datasets
...@@ -61,7 +63,7 @@ def hartree2kcal(x): ...@@ -61,7 +63,7 @@ def hartree2kcal(x):
return 627.509 * x return 627.509 * x
for dataset in datasets: def evaluate(dataset, container):
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies') 'RMSE': torchani.ignite.RMSEMetric('energies')
}) })
...@@ -69,3 +71,14 @@ for dataset in datasets: ...@@ -69,3 +71,14 @@ for dataset in datasets:
metrics = evaluator.state.metrics metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE']) rmse = hartree2kcal(metrics['RMSE'])
print(rmse, 'kcal/mol') print(rmse, 'kcal/mol')
for dataset in datasets:
evaluate(dataset, container)
if parser.compare_with is not None:
nn.load_state_dict(torch.load(parser.compare_with))
print('TorchANI results:')
for dataset in datasets:
evaluate(dataset, container)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment