Unverified Commit a9a792a2 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Allow running NeuroChem trainer by `python -m torchani.neurochem.trainer` (#82)

parent c37e6e1f
...@@ -32,6 +32,7 @@ steps: ...@@ -32,6 +32,7 @@ steps:
- python examples/energy_force.py - python examples/energy_force.py
- python examples/neurochem-test.py ./dataset/ani_gdb_s01.h5 - python examples/neurochem-test.py ./dataset/ani_gdb_s01.h5
- python examples/inference-benchmark.py examples/xyz_files/CH4-5.xyz - python examples/inference-benchmark.py examples/xyz_files/CH4-5.xyz
- python -m torchani.neurochem.trainer tests/test_data/inputtrain.ipt dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5
Docs: Docs:
image: '${{BuildTorchANI}}' image: '${{BuildTorchANI}}'
......
...@@ -28,8 +28,8 @@ Utilities ...@@ -28,8 +28,8 @@ Utilities
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
NeuroChem Importers NeuroChem Utils
=================== ===============
.. automodule:: torchani.neurochem .. automodule:: torchani.neurochem
.. autoclass:: torchani.neurochem.Constants .. autoclass:: torchani.neurochem.Constants
...@@ -41,6 +41,7 @@ NeuroChem Importers ...@@ -41,6 +41,7 @@ NeuroChem Importers
.. autoclass:: torchani.neurochem.Buildins .. autoclass:: torchani.neurochem.Buildins
.. autoclass:: torchani.neurochem.Trainer .. autoclass:: torchani.neurochem.Trainer
:members: :members:
.. automodule:: torchani.neurochem.trainer
Ignite Helpers Ignite Helpers
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Tools for loading NeuroChem input files.""" """Tools for loading/running NeuroChem input files."""
import pkg_resources import pkg_resources
import torch import torch
...@@ -12,11 +12,11 @@ import ignite ...@@ -12,11 +12,11 @@ import ignite
import math import math
import timeit import timeit
from collections.abc import Mapping from collections.abc import Mapping
from .nn import ANIModel, Ensemble, Gaussian from ..nn import ANIModel, Ensemble, Gaussian
from .utils import EnergyShifter from ..utils import EnergyShifter
from .aev import AEVComputer from ..aev import AEVComputer
from .ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
from .data import BatchedANIDataset from ..data import BatchedANIDataset
class Constants(Mapping): class Constants(Mapping):
...@@ -281,19 +281,20 @@ class Buildins: ...@@ -281,19 +281,20 @@ class Buildins:
""" """
def __init__(self): def __init__(self):
parent_name = '.'.join(__name__.split('.')[:-1])
self.const_file = pkg_resources.resource_filename( self.const_file = pkg_resources.resource_filename(
__name__, parent_name,
'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params') 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
self.consts = Constants(self.const_file) self.consts = Constants(self.const_file)
self.aev_computer = AEVComputer(**self.consts) self.aev_computer = AEVComputer(**self.consts)
self.sae_file = pkg_resources.resource_filename( self.sae_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat') parent_name, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
self.energy_shifter = load_sae(self.sae_file) self.energy_shifter = load_sae(self.sae_file)
self.ensemble_size = 8 self.ensemble_size = 8
self.ensemble_prefix = pkg_resources.resource_filename( self.ensemble_prefix = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train') parent_name, 'resources/ani-1x_dft_x8ens/train')
self.models = load_model_ensemble(self.consts.species, self.models = load_model_ensemble(self.consts.species,
self.ensemble_prefix, self.ensemble_prefix,
self.ensemble_size) self.ensemble_size)
......
# -*- coding: utf-8 -*-
"""Besides running NeuroChem trainer by programming, we can also run it by
``python -m torchani.neurochem.trainer``, use the ``-h`` option for help.
"""
import torch
from . import Trainer
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('config_path',
help='Path of the training config file `.ipt`')
parser.add_argument('training_path',
help='Path of the training set, can be a hdf5 file \
or a directory containing hdf5 files')
parser.add_argument('validation_path',
help='Path of the validation set, can be a hdf5 file \
or a directory containing hdf5 files')
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser.add_argument('-d', '--device', help='Device for training',
default=default_device)
parser.add_argument('--tqdm', help='Whether to enable tqdm',
dest='tqdm', action='store_true')
parser.add_argument('--tensorboard',
help='Directory to store tensorboard log files',
default=None)
parser = parser.parse_args()
d = torch.device(parser.device)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard)
trainer.load_data(parser.training_path, parser.validation_path)
trainer.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment