"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d2e5cb3c1072ad324d1c9c4bf19be98bc4280282"
Unverified commit d3ae0788, authored by Gao, Xiang, committed by GitHub

remove explicit device and dtype (#44)

parent 8c493a6e
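In short: the AEV constants become registered buffers and PrepareInput reads the device from its input coordinates, so callers no longer pass dtype/device into TorchANI modules; they place the model and the input tensors on a device themselves. A minimal sketch of the intended usage after this change (the `.to(device)` call and variable names are illustrative, not part of this diff, and the buildin NeuroChemNNP networks are assumed to be installed):

import torch
import torchani

# pick a device at the call site, as the updated training scripts do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
model = torch.nn.Sequential(prepare, aev_computer, nnp).to(device)

# the input tensor carries its own dtype and device; PrepareInput builds the
# species index tensor on coordinates.device at forward time
coordinates = torch.randn(1, 5, 3, device=device, requires_grad=True)  # random values, sketch only
species = ['C', 'H', 'H', 'H', 'H']
energies = model((species, coordinates))  # input packing follows PrepareInput.forward; output assumed to be energies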
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
 sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat')  # noqa: E501
 network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train')  # noqa: E501
-aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
-prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
+aev_computer = torchani.SortedAEV(const_file=const_file)
+prepare = torchani.PrepareInput(aev_computer.species)
 nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
                                   ensemble=8)
 model = torch.nn.Sequential(prepare, aev_computer, nn)
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [-0.66518241, -0.84461308, 0.20759389],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
-                           dtype=aev_computer.dtype,
-                           device=aev_computer.device,
                            requires_grad=True)
 species = ['C', 'H', 'H', 'H', 'H']
...
@@ -18,9 +18,9 @@ def atomic():
 def get_or_create_model(filename, benchmark=False,
-                        device=torchani.default_device):
-    aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
-    prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
+                        device=torch.device('cpu')):
+    aev_computer = torchani.SortedAEV(benchmark=benchmark)
+    prepare = torchani.PrepareInput(aev_computer.species)
     model = torchani.models.CustomModel(
         reducer=torch.sum,
         benchmark=benchmark,
...
@@ -8,6 +8,8 @@ import timeit
 import tensorboardX
 import math
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 chunk_size = 256
 batch_chunks = 4
 dataset_path = sys.argv[1]
@@ -20,11 +22,11 @@ start = timeit.default_timer()
 shift_energy = torchani.EnergyShifter()
 training, validation, testing = torchani.data.load_or_create(
-    dataset_checkpoint, dataset_path, chunk_size,
+    dataset_checkpoint, dataset_path, chunk_size, device=device,
     transform=[shift_energy.dataset_subtract_sae])
 training = torchani.data.dataloader(training, batch_chunks)
 validation = torchani.data.dataloader(validation, batch_chunks)
-nnp = model.get_or_create_model(model_checkpoint)
+nnp = model.get_or_create_model(model_checkpoint, device=device)
 batch_nnp = torchani.models.BatchModel(nnp)
 container = torchani.ignite.Container({'energies': batch_nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
...
@@ -6,15 +6,17 @@ import timeit
 import model
 import tqdm
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 chunk_size = 256
 batch_chunks = 4
 dataset_path = sys.argv[1]
 shift_energy = torchani.EnergyShifter()
 dataset = torchani.data.ANIDataset(
-    dataset_path, chunk_size,
+    dataset_path, chunk_size, device=device,
     transform=[shift_energy.dataset_subtract_sae])
 dataloader = torchani.data.dataloader(dataset, batch_chunks)
-nnp = model.get_or_create_model('/tmp/model.pt', True)
+nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
 batch_nnp = torchani.models.BatchModel(nnp)
 container = torchani.ignite.Container({'energies': batch_nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
...
@@ -10,12 +10,11 @@ N = 97
 class TestAEV(unittest.TestCase):
 
-    def setUp(self, dtype=torchani.default_dtype):
-        aev_computer = torchani.SortedAEV(dtype=dtype,
-                                          device=torch.device('cpu'))
+    def setUp(self):
+        aev_computer = torchani.SortedAEV()
         self.radial_length = aev_computer.radial_length
         self.aev = torch.nn.Sequential(
-            torchani.PrepareInput(aev_computer.species, aev_computer.device),
+            torchani.PrepareInput(aev_computer.species),
             aev_computer
         )
         self.tolerance = 1e-5
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
 path = os.path.join(path, '../dataset')
 chunksize = 32
 batch_chunks = 32
-dtype = torch.float32
-device = torch.device('cpu')
 
 class TestBatch(unittest.TestCase):
 
     def testBatchLoadAndInference(self):
-        ds = torchani.data.ANIDataset(path, chunksize, device=device)
+        ds = torchani.data.ANIDataset(path, chunksize)
         loader = torchani.data.dataloader(ds, batch_chunks)
-        aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         model = torch.nn.Sequential(prepare, aev_computer, nnp)
         batch_nnp = torchani.models.BatchModel(model)
...
@@ -6,14 +6,10 @@ import copy
 class TestBenchmark(unittest.TestCase):
 
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
-        self.dtype = dtype
-        self.device = device
+    def setUp(self):
         self.conformations = 100
         self.species = list('HHCCNNOO')
-        self.coordinates = torch.randn(
-            self.conformations, 8, 3, dtype=dtype, device=device)
+        self.coordinates = torch.randn(self.conformations, 8, 3)
         self.count = 100
 
     def _testModule(self, run_module, result_module, asserts):
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
             self.assertEqual(result_module.timers[i], 0)
 
     def testAEV(self):
-        aev_computer = torchani.SortedAEV(
-            benchmark=True, dtype=self.dtype, device=self.device)
-        prepare = torchani.PrepareInput(aev_computer.species, self.device)
+        aev_computer = torchani.SortedAEV(benchmark=True)
+        prepare = torchani.PrepareInput(aev_computer.species)
         run_module = torch.nn.Sequential(prepare, aev_computer)
         self._testModule(run_module, aev_computer, [
             'terms and indices>radial terms',
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
         ])
 
     def testANIModel(self):
-        aev_computer = torchani.SortedAEV(
-            dtype=self.dtype, device=self.device)
-        prepare = torchani.PrepareInput(aev_computer.species, self.device)
-        model = torchani.models.NeuroChemNNP(
-            aev_computer.species, benchmark=True).to(self.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
+        model = torchani.models.NeuroChemNNP(aev_computer.species,
+                                             benchmark=True)
         run_module = torch.nn.Sequential(prepare, aev_computer, model)
         self._testModule(run_module, model, ['forward'])
...
@@ -11,13 +11,10 @@ N = 97
 class TestEnergies(unittest.TestCase):
 
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
+    def setUp(self):
         self.tolerance = 5e-5
-        aev_computer = torchani.SortedAEV(
-            dtype=dtype, device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
         coordinates = torch.tensor(coordinates, requires_grad=True)
         n = torchani.buildin_ensemble
         prefix = torchani.buildin_model_prefix
-        aev = torchani.SortedAEV(device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev.species, aev.device)
+        aev = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev.species)
         ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
         ensemble = torch.nn.Sequential(prepare, aev, ensemble)
         models = [torchani.models.
...
@@ -10,13 +10,10 @@ N = 97
 class TestForce(unittest.TestCase):
 
-    def setUp(self, dtype=torchani.default_dtype,
-              device=torchani.default_device):
+    def setUp(self):
         self.tolerance = 1e-5
-        aev_computer = torchani.SortedAEV(
-            dtype=dtype, device=torch.device('cpu'))
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
 path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
 chunksize = 4
 threshold = 1e-5
-dtype = torch.float32
-device = torch.device('cpu')
 
 class TestIgnite(unittest.TestCase):
 
     def testIgnite(self):
         shift_energy = torchani.EnergyShifter()
         ds = torchani.data.ANIDataset(
-            path, chunksize, device=device,
-            transform=[shift_energy.dataset_subtract_sae])
+            path, chunksize, transform=[shift_energy.dataset_subtract_sae])
         ds = torch.utils.data.Subset(ds, [0])
         loader = torchani.data.dataloader(ds, 1)
-        aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
-        prepare = torchani.PrepareInput(aev_computer.species,
-                                        aev_computer.device)
+        aev_computer = torchani.SortedAEV()
+        prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
 
         class Flatten(torch.nn.Module):
...
@@ -4,10 +4,9 @@ from . import data
 from . import ignite
 from .aev import SortedAEV, PrepareInput
 from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
-    buildin_model_prefix, buildin_ensemble, default_dtype, default_device
+    buildin_model_prefix, buildin_ensemble
 
 __all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter',
            'models', 'data', 'ignite',
           'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
-           'buildin_model_prefix', 'buildin_ensemble',
-           'default_dtype', 'default_device']
+           'buildin_model_prefix', 'buildin_ensemble']
...
 import torch
 import itertools
 import math
-from .env import buildin_const_file, default_dtype, default_device
+from .env import buildin_const_file
 from .benchmarked import BenchmarkedModule
 
 
 class AEVComputer(BenchmarkedModule):
-    __constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
-                     'radial_length', 'angular_sublength', 'angular_length',
-                     'aev_length']
+    __constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
+                     'angular_sublength', 'angular_length', 'aev_length']
     """Base class of various implementations of AEV computer
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
     ----------
     benchmark : boolean
         Whether to enable benchmark
-    dtype : torch.dtype
-        Data type of pytorch tensors for all the computations. This is
-        also used to specify whether to use CPU or GPU.
-    device : torch.Device
-        The device where tensors should be.
     const_file : str
         The name of the original file that stores constant.
     Rcr, Rca : float
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
         The length of full aev
     """
 
-    def __init__(self, benchmark=False, dtype=default_dtype,
-                 device=default_device, const_file=buildin_const_file):
+    def __init__(self, benchmark=False, const_file=buildin_const_file):
         super(AEVComputer, self).__init__(benchmark)
-        self.dtype = dtype
         self.const_file = const_file
-        self.device = device
 
         # load constants from const file
+        const = {}
         with open(const_file) as f:
             for i in f:
                 try:
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
                                   'ShfZ', 'EtaA', 'ShfA']:
                         value = [float(x.strip()) for x in value.replace(
                             '[', '').replace(']', '').split(',')]
-                        value = torch.tensor(value, dtype=dtype, device=device)
-                        setattr(self, name, value)
+                        value = torch.tensor(value)
+                        const[name] = value
                     elif name == 'Atyp':
                         value = [x.strip() for x in value.replace(
                             '[', '').replace(']', '').split(',')]
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
                     raise ValueError('unable to parse const file')
 
         # Compute lengths
-        self.radial_sublength = self.EtaR.shape[0] * self.ShfR.shape[0]
+        self.radial_sublength = const['EtaR'].shape[0] * const['ShfR'].shape[0]
         self.radial_length = len(self.species) * self.radial_sublength
-        self.angular_sublength = self.EtaA.shape[0] * \
-            self.Zeta.shape[0] * self.ShfA.shape[0] * self.ShfZ.shape[0]
+        self.angular_sublength = const['EtaA'].shape[0] * \
+            const['Zeta'].shape[0] * const['ShfA'].shape[0] * \
+            const['ShfZ'].shape[0]
         species = len(self.species)
         self.angular_length = int(
             (species * (species + 1)) / 2) * self.angular_sublength
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
         # convert constant tensors to a ready-to-broadcast shape
         # shape convension (..., EtaR, ShfR)
-        self.EtaR = self.EtaR.view(-1, 1)
-        self.ShfR = self.ShfR.view(1, -1)
+        const['EtaR'] = const['EtaR'].view(-1, 1)
+        const['ShfR'] = const['ShfR'].view(1, -1)
         # shape convension (..., EtaA, Zeta, ShfA, ShfZ)
-        self.EtaA = self.EtaA.view(-1, 1, 1, 1)
-        self.Zeta = self.Zeta.view(1, -1, 1, 1)
-        self.ShfA = self.ShfA.view(1, 1, -1, 1)
-        self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
+        const['EtaA'] = const['EtaA'].view(-1, 1, 1, 1)
+        const['Zeta'] = const['Zeta'].view(1, -1, 1, 1)
+        const['ShfA'] = const['ShfA'].view(1, 1, -1, 1)
+        const['ShfZ'] = const['ShfZ'].view(1, 1, 1, -1)
+
+        # register buffers
+        for i in const:
+            self.register_buffer(i, const[i])
 
     def forward(self, coordinates_species):
         """Compute AEV from coordinates and species
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
 class PrepareInput(torch.nn.Module):
 
-    def __init__(self, species, device):
+    def __init__(self, species):
         super(PrepareInput, self).__init__()
         self.species = species
-        self.device = device
 
-    def species_to_tensor(self, species):
+    def species_to_tensor(self, species, device):
         """Convert species list into a long tensor.
 
         Parameters
        ----------
         species : list
             List of string for the species of each atoms.
+        device : torch.device
+            The device to store tensor
 
         Returns
         -------
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
         """
         indices = {self.species[i]: i for i in range(len(self.species))}
         values = [indices[i] for i in species]
-        return torch.tensor(values, dtype=torch.long, device=self.device)
+        return torch.tensor(values, dtype=torch.long, device=device)
 
     def sort_by_species(self, species, *tensors):
         """Sort the data by its species according to the order in `self.species`
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
     def forward(self, species_coordinates):
         species, coordinates = species_coordinates
-        species = self.species_to_tensor(species)
+        species = self.species_to_tensor(species, coordinates.device)
         return self.sort_by_species(species, coordinates)
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
     total : total time for computing everything.
     """
 
-    def __init__(self, benchmark=False, device=default_device,
-                 dtype=default_dtype, const_file=buildin_const_file):
-        super(SortedAEV, self).__init__(benchmark, dtype, device, const_file)
+    def __init__(self, benchmark=False, const_file=buildin_const_file):
+        super(SortedAEV, self).__init__(benchmark, const_file)
         if benchmark:
             self.radial_subaev_terms = self._enable_benchmark(
                 self.radial_subaev_terms, 'radial terms')
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
             storing the mask for each species.
         """
         mask_r = (species_r.unsqueeze(-1) ==
-                  torch.arange(len(self.species), device=self.device))
+                  torch.arange(len(self.species), device=self.EtaR.device))
         return mask_r
 
     def compute_mask_a(self, species_a, present_species):
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
         atoms = radial_terms.shape[1]
 
         # assemble radial subaev
-        present_radial_aevs = (radial_terms.unsqueeze(-2)
-                               * mask_r.unsqueeze(-1).type(self.dtype)).sum(-3)
+        present_radial_aevs = (
+            radial_terms.unsqueeze(-2) *
+            mask_r.unsqueeze(-1).type(radial_terms.dtype)
+        ).sum(-3)
         """shape (conformations, atoms, present species, radial_length)"""
         radial_aevs = present_radial_aevs.flatten(start_dim=2)
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
         zero_angular_subaev = torch.zeros(
             # TODO: can we make stack and cat broadcast?
             conformations, atoms, self.angular_sublength,
-            dtype=self.dtype, device=self.device)
+            dtype=self.EtaR.dtype, device=self.EtaR.device)
         for s1, s2 in itertools.combinations_with_replacement(
                 range(len(self.species)), 2):
             if s1 in rev_indices and s2 in rev_indices:
                 i1 = rev_indices[s1]
                 i2 = rev_indices[s2]
-                mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.dtype)
+                mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
                 subaev = (angular_terms * mask).sum(-2)
             else:
                 subaev = zero_angular_subaev
...
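The core of the change in this file is the switch from plain tensor attributes to registered buffers for the AEV constants (EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ). Buffers are not trainable parameters, but they follow the module through `.to()` / `.cuda()` and are saved in `state_dict()`, which is why an explicit dtype/device argument is no longer needed and downstream code can read `self.EtaR.dtype` and `self.EtaR.device` instead. A standalone sketch of that buffer behaviour (toy module, not TorchANI code):

import torch

class WithConstants(torch.nn.Module):
    def __init__(self):
        super(WithConstants, self).__init__()
        # a constant registered as a buffer: not a parameter, but it moves
        # with the module and is saved in state_dict()
        self.register_buffer('EtaR', torch.tensor([16.0]).view(-1, 1))

    def forward(self, x):
        # the buffer's dtype/device track wherever the module was moved
        return x * self.EtaR

m = WithConstants()
print(m.EtaR.device)               # cpu
print('EtaR' in m.state_dict())    # True
if torch.cuda.is_available():
    m = m.cuda()
    print(m.EtaR.device)           # cuda:0, with no per-tensor bookkeeping

PrepareInput applies the same idea from the other direction: it stops storing a device and instead creates the species index tensor on `coordinates.device` inside `forward`.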
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
 from os.path import join, isfile, isdir
 import os
 from .pyanitools import anidataloader
-from .env import default_dtype, default_device
 import torch
 import torch.utils.data as data
 import pickle
@@ -11,7 +10,8 @@ import pickle
 class ANIDataset(Dataset):
 
     def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
-                 transform=(), dtype=default_dtype, device=default_device):
+                 transform=(), dtype=torch.get_default_dtype(),
+                 device=torch.device('cpu')):
         super(ANIDataset, self).__init__()
         self.path = path
         self.chunks_size = chunk_size
...
 import pkg_resources
-import torch
 
 buildin_const_file = pkg_resources.resource_filename(
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
     __name__, 'resources/ani-1x_dft_x8ens/train')
 buildin_ensemble = 8
-default_dtype = torch.float32
-default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")