Unverified commit d3ae0788 authored by Gao, Xiang, committed by GitHub

remove explicit device and dtype (#44)

parent 8c493a6e
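This commit drops the explicit device/dtype plumbing from the TorchANI building blocks: AEV constants become registered buffers, PrepareInput takes its device from the incoming coordinates, and the library-wide default_dtype/default_device globals are removed. Below is a minimal before/after sketch of how calling code changes, assuming the API exercised in the diffs (SortedAEV, PrepareInput, NeuroChemNNP); it is illustrative only and not part of the commit:

import torch
import torchani

# Before this commit, placement was threaded through every constructor:
#   aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
#   prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)

# After this commit, components are built without device/dtype and the whole
# pipeline is moved like any other torch.nn.Module:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
model = torch.nn.Sequential(prepare, aev_computer, nnp).to(device)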
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
 sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat') # noqa: E501
 network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
-aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
-prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
+aev_computer = torchani.SortedAEV(const_file=const_file)
+prepare = torchani.PrepareInput(aev_computer.species)
 nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
                                   ensemble=8)
 model = torch.nn.Sequential(prepare, aev_computer, nn)
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                             [-0.66518241, -0.84461308, 0.20759389],
                             [0.45554739, 0.54289633, 0.81170881],
                             [0.66091919, -0.16799635, -0.91037834]]],
-                           dtype=aev_computer.dtype,
-                           device=aev_computer.device,
                            requires_grad=True)
 species = ['C', 'H', 'H', 'H', 'H']
......
@@ -18,9 +18,9 @@ def atomic():
 def get_or_create_model(filename, benchmark=False,
-                        device=torchani.default_device):
-    aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
-    prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
+                        device=torch.device('cpu')):
+    aev_computer = torchani.SortedAEV(benchmark=benchmark)
+    prepare = torchani.PrepareInput(aev_computer.species)
     model = torchani.models.CustomModel(
         reducer=torch.sum,
         benchmark=benchmark,
......
@@ -8,6 +8,8 @@ import timeit
 import tensorboardX
 import math
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 chunk_size = 256
 batch_chunks = 4
 dataset_path = sys.argv[1]
@@ -20,11 +22,11 @@ start = timeit.default_timer()
 shift_energy = torchani.EnergyShifter()
 training, validation, testing = torchani.data.load_or_create(
-    dataset_checkpoint, dataset_path, chunk_size,
+    dataset_checkpoint, dataset_path, chunk_size, device=device,
     transform=[shift_energy.dataset_subtract_sae])
 training = torchani.data.dataloader(training, batch_chunks)
 validation = torchani.data.dataloader(validation, batch_chunks)
-nnp = model.get_or_create_model(model_checkpoint)
+nnp = model.get_or_create_model(model_checkpoint, device=device)
 batch_nnp = torchani.models.BatchModel(nnp)
 container = torchani.ignite.Container({'energies': batch_nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
......
@@ -6,15 +6,17 @@ import timeit
 import model
 import tqdm
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 chunk_size = 256
 batch_chunks = 4
 dataset_path = sys.argv[1]
 shift_energy = torchani.EnergyShifter()
 dataset = torchani.data.ANIDataset(
-    dataset_path, chunk_size,
+    dataset_path, chunk_size, device=device,
     transform=[shift_energy.dataset_subtract_sae])
 dataloader = torchani.data.dataloader(dataset, batch_chunks)
-nnp = model.get_or_create_model('/tmp/model.pt', True)
+nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
 batch_nnp = torchani.models.BatchModel(nnp)
 container = torchani.ignite.Container({'energies': batch_nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
......
@@ -10,12 +10,11 @@ N = 97
 class TestAEV(unittest.TestCase):
-    def setUp(self, dtype=torchani.default_dtype):
-        aev_computer = torchani.SortedAEV(dtype=dtype,
-                                          device=torch.device('cpu'))
+    def setUp(self):
+        aev_computer = torchani.SortedAEV()
         self.radial_length = aev_computer.radial_length
         self.aev = torch.nn.Sequential(
-            torchani.PrepareInput(aev_computer.species, aev_computer.device),
+            torchani.PrepareInput(aev_computer.species),
             aev_computer
         )
         self.tolerance = 1e-5
......
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
 path = os.path.join(path, '../dataset')
 chunksize = 32
 batch_chunks = 32
-dtype = torch.float32
-device = torch.device('cpu')
 class TestBatch(unittest.TestCase):
     def testBatchLoadAndInference(self):
-        ds = torchani.data.ANIDataset(path, chunksize, device=device)
+        ds = torchani.data.ANIDataset(path, chunksize)
         loader = torchani.data.dataloader(ds, batch_chunks)
-        aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         model = torch.nn.Sequential(prepare, aev_computer, nnp)
         batch_nnp = torchani.models.BatchModel(model)
......
@@ -6,14 +6,10 @@ import copy
 class TestBenchmark(unittest.TestCase):
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
-        self.dtype = dtype
-        self.device = device
+    def setUp(self):
         self.conformations = 100
         self.species = list('HHCCNNOO')
-        self.coordinates = torch.randn(
-            self.conformations, 8, 3, dtype=dtype, device=device)
+        self.coordinates = torch.randn(self.conformations, 8, 3)
         self.count = 100
     def _testModule(self, run_module, result_module, asserts):
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
             self.assertEqual(result_module.timers[i], 0)
     def testAEV(self):
-        aev_computer = torchani.SortedAEV(
-            benchmark=True, dtype=self.dtype, device=self.device)
-        prepare = torchani.PrepareInput(aev_computer.species, self.device)
+        aev_computer = torchani.SortedAEV(benchmark=True)
+        prepare = torchani.PrepareInput(aev_computer.species)
         run_module = torch.nn.Sequential(prepare, aev_computer)
         self._testModule(run_module, aev_computer, [
             'terms and indices>radial terms',
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
         ])
     def testANIModel(self):
-        aev_computer = torchani.SortedAEV(
-            dtype=self.dtype, device=self.device)
-        prepare = torchani.PrepareInput(aev_computer.species, self.device)
-        model = torchani.models.NeuroChemNNP(
-            aev_computer.species, benchmark=True).to(self.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
+        model = torchani.models.NeuroChemNNP(aev_computer.species,
+                                             benchmark=True)
         run_module = torch.nn.Sequential(prepare, aev_computer, model)
         self._testModule(run_module, model, ['forward'])
......
@@ -11,13 +11,10 @@ N = 97
 class TestEnergies(unittest.TestCase):
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
+    def setUp(self):
         self.tolerance = 5e-5
-        aev_computer = torchani.SortedAEV(
-            dtype=dtype, device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
......
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
         coordinates = torch.tensor(coordinates, requires_grad=True)
         n = torchani.buildin_ensemble
         prefix = torchani.buildin_model_prefix
-        aev = torchani.SortedAEV(device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev.species, aev.device)
+        aev = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev.species)
         ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
         ensemble = torch.nn.Sequential(prepare, aev, ensemble)
         models = [torchani.models.
......
@@ -10,13 +10,10 @@ N = 97
 class TestForce(unittest.TestCase):
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
+    def setUp(self):
         self.tolerance = 1e-5
-        aev_computer = torchani.SortedAEV(
-            dtype=dtype, device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
......
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
 path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
 chunksize = 4
 threshold = 1e-5
-dtype = torch.float32
-device = torch.device('cpu')
 class TestIgnite(unittest.TestCase):
     def testIgnite(self):
         shift_energy = torchani.EnergyShifter()
         ds = torchani.data.ANIDataset(
-            path, chunksize, device=device,
-            transform=[shift_energy.dataset_subtract_sae])
+            path, chunksize, transform=[shift_energy.dataset_subtract_sae])
         ds = torch.utils.data.Subset(ds, [0])
         loader = torchani.data.dataloader(ds, 1)
-        aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         class Flatten(torch.nn.Module):
......
@@ -4,10 +4,9 @@ from . import data
 from . import ignite
 from .aev import SortedAEV, PrepareInput
 from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
-    buildin_model_prefix, buildin_ensemble, default_dtype, default_device
+    buildin_model_prefix, buildin_ensemble
 __all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter',
            'models', 'data', 'ignite',
            'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
-           'buildin_model_prefix', 'buildin_ensemble',
-           'default_dtype', 'default_device']
+           'buildin_model_prefix', 'buildin_ensemble']
......
 import torch
 import itertools
 import math
-from .env import buildin_const_file, default_dtype, default_device
+from .env import buildin_const_file
 from .benchmarked import BenchmarkedModule
 class AEVComputer(BenchmarkedModule):
-    __constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
-                     'radial_length', 'angular_sublength', 'angular_length',
-                     'aev_length']
+    __constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
+                     'angular_sublength', 'angular_length', 'aev_length']
     """Base class of various implementations of AEV computer
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
     ----------
     benchmark : boolean
         Whether to enable benchmark
-    dtype : torch.dtype
-        Data type of pytorch tensors for all the computations. This is
-        also used to specify whether to use CPU or GPU.
-    device : torch.Device
-        The device where tensors should be.
     const_file : str
         The name of the original file that stores constant.
     Rcr, Rca : float
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
         The length of full aev
     """
-    def __init__(self, benchmark=False, dtype=default_dtype,
-                 device=default_device, const_file=buildin_const_file):
+    def __init__(self, benchmark=False, const_file=buildin_const_file):
         super(AEVComputer, self).__init__(benchmark)
-        self.dtype = dtype
         self.const_file = const_file
-        self.device = device
         # load constants from const file
+        const = {}
         with open(const_file) as f:
             for i in f:
                 try:
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
                                   'ShfZ', 'EtaA', 'ShfA']:
                         value = [float(x.strip()) for x in value.replace(
                             '[', '').replace(']', '').split(',')]
-                        value = torch.tensor(value, dtype=dtype, device=device)
-                        setattr(self, name, value)
+                        value = torch.tensor(value)
+                        const[name] = value
                     elif name == 'Atyp':
                         value = [x.strip() for x in value.replace(
                             '[', '').replace(']', '').split(',')]
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
                     raise ValueError('unable to parse const file')
         # Compute lengths
-        self.radial_sublength = self.EtaR.shape[0] * self.ShfR.shape[0]
+        self.radial_sublength = const['EtaR'].shape[0] * const['ShfR'].shape[0]
         self.radial_length = len(self.species) * self.radial_sublength
-        self.angular_sublength = self.EtaA.shape[0] * \
-            self.Zeta.shape[0] * self.ShfA.shape[0] * self.ShfZ.shape[0]
+        self.angular_sublength = const['EtaA'].shape[0] * \
+            const['Zeta'].shape[0] * const['ShfA'].shape[0] * \
+            const['ShfZ'].shape[0]
         species = len(self.species)
         self.angular_length = int(
             (species * (species + 1)) / 2) * self.angular_sublength
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
         # convert constant tensors to a ready-to-broadcast shape
         # shape convension (..., EtaR, ShfR)
-        self.EtaR = self.EtaR.view(-1, 1)
-        self.ShfR = self.ShfR.view(1, -1)
+        const['EtaR'] = const['EtaR'].view(-1, 1)
+        const['ShfR'] = const['ShfR'].view(1, -1)
         # shape convension (..., EtaA, Zeta, ShfA, ShfZ)
-        self.EtaA = self.EtaA.view(-1, 1, 1, 1)
-        self.Zeta = self.Zeta.view(1, -1, 1, 1)
-        self.ShfA = self.ShfA.view(1, 1, -1, 1)
-        self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
+        const['EtaA'] = const['EtaA'].view(-1, 1, 1, 1)
+        const['Zeta'] = const['Zeta'].view(1, -1, 1, 1)
+        const['ShfA'] = const['ShfA'].view(1, 1, -1, 1)
+        const['ShfZ'] = const['ShfZ'].view(1, 1, 1, -1)
+        # register buffers
+        for i in const:
+            self.register_buffer(i, const[i])
     def forward(self, coordinates_species):
         """Compute AEV from coordinates and species
@@ -112,18 +108,19 @@
 class PrepareInput(torch.nn.Module):
-    def __init__(self, species, device):
+    def __init__(self, species):
         super(PrepareInput, self).__init__()
         self.species = species
-        self.device = device
-    def species_to_tensor(self, species):
+    def species_to_tensor(self, species, device):
         """Convert species list into a long tensor.
         Parameters
         ----------
         species : list
             List of string for the species of each atoms.
+        device : torch.device
+            The device to store tensor
         Returns
         -------
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
         """
         indices = {self.species[i]: i for i in range(len(self.species))}
         values = [indices[i] for i in species]
-        return torch.tensor(values, dtype=torch.long, device=self.device)
+        return torch.tensor(values, dtype=torch.long, device=device)
     def sort_by_species(self, species, *tensors):
         """Sort the data by its species according to the order in `self.species`
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
     def forward(self, species_coordinates):
         species, coordinates = species_coordinates
-        species = self.species_to_tensor(species)
+        species = self.species_to_tensor(species, coordinates.device)
         return self.sort_by_species(species, coordinates)
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
     total : total time for computing everything.
     """
-    def __init__(self, benchmark=False, device=default_device,
-                 dtype=default_dtype, const_file=buildin_const_file):
-        super(SortedAEV, self).__init__(benchmark, dtype, device, const_file)
+    def __init__(self, benchmark=False, const_file=buildin_const_file):
+        super(SortedAEV, self).__init__(benchmark, const_file)
         if benchmark:
             self.radial_subaev_terms = self._enable_benchmark(
                 self.radial_subaev_terms, 'radial terms')
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
             storing the mask for each species.
         """
         mask_r = (species_r.unsqueeze(-1) ==
-                  torch.arange(len(self.species), device=self.device))
+                  torch.arange(len(self.species), device=self.EtaR.device))
         return mask_r
     def compute_mask_a(self, species_a, present_species):
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
         atoms = radial_terms.shape[1]
         # assemble radial subaev
-        present_radial_aevs = (radial_terms.unsqueeze(-2)
-                               * mask_r.unsqueeze(-1).type(self.dtype)).sum(-3)
+        present_radial_aevs = (
+            radial_terms.unsqueeze(-2) *
+            mask_r.unsqueeze(-1).type(radial_terms.dtype)
+        ).sum(-3)
         """shape (conformations, atoms, present species, radial_length)"""
         radial_aevs = present_radial_aevs.flatten(start_dim=2)
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
         zero_angular_subaev = torch.zeros(
             # TODO: can we make stack and cat broadcast?
             conformations, atoms, self.angular_sublength,
-            dtype=self.dtype, device=self.device)
+            dtype=self.EtaR.dtype, device=self.EtaR.device)
         for s1, s2 in itertools.combinations_with_replacement(
                 range(len(self.species)), 2):
             if s1 in rev_indices and s2 in rev_indices:
                 i1 = rev_indices[s1]
                 i2 = rev_indices[s2]
-                mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.dtype)
+                mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
                 subaev = (angular_terms * mask).sum(-2)
             else:
                 subaev = zero_angular_subaev
......
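For context on the hunks above: storing the AEV constants with register_buffer is the standard PyTorch way to make non-parameter tensors follow the module through .to()/.cuda()/.double() and appear in state_dict(), which is why downstream code can now read self.EtaR.device and self.EtaR.dtype instead of the removed self.device/self.dtype attributes. A tiny stand-alone illustration of that mechanism (plain PyTorch, not TorchANI code):

import torch


class WithBuffer(torch.nn.Module):
    def __init__(self):
        super(WithBuffer, self).__init__()
        # A plain tensor attribute would be left behind by .to()/.double();
        # a registered buffer moves with the module and is serialized.
        self.register_buffer('EtaR', torch.tensor([16.0]).view(-1, 1))

    def forward(self, x):
        # Callers can derive placement from the buffer, as SortedAEV now does.
        return x.to(self.EtaR.device, self.EtaR.dtype) * self.EtaR


m = WithBuffer().double()
print(m.EtaR.dtype)              # torch.float64
print('EtaR' in m.state_dict())  # True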
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
 from os.path import join, isfile, isdir
 import os
 from .pyanitools import anidataloader
-from .env import default_dtype, default_device
 import torch
 import torch.utils.data as data
 import pickle
@@ -11,7 +10,8 @@ import pickle
 class ANIDataset(Dataset):
     def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
-                 transform=(), dtype=default_dtype, device=default_device):
+                 transform=(), dtype=torch.get_default_dtype(),
+                 device=torch.device('cpu')):
         super(ANIDataset, self).__init__()
         self.path = path
         self.chunks_size = chunk_size
......
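The dataset side keeps explicit placement but stops depending on the removed globals: ANIDataset now defaults to torch.get_default_dtype() and CPU, and the example scripts above opt into CUDA by passing device= themselves. A hypothetical loading snippet in that style (the path and sizes are illustrative only):

import torch
import torchani

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
shift_energy = torchani.EnergyShifter()

# dtype/device now default to torch.get_default_dtype() and CPU; pass device=
# explicitly to place batches on the GPU, as the training example does.
dataset = torchani.data.ANIDataset(
    'dataset/ani_gdb_s01.h5', 256, device=device,
    transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, 4)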
 import pkg_resources
-import torch
 buildin_const_file = pkg_resources.resource_filename(
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
     __name__, 'resources/ani-1x_dft_x8ens/train')
 buildin_ensemble = 8
-default_dtype = torch.float32
-default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")