Unverified Commit 5bb66915 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

fix sort_by_species (#57)

parent e5439f3d
......@@ -7,31 +7,29 @@ import numpy
import torchani
import pickle
from torchani import buildin_const_file, buildin_sae_file, \
buildin_network_dir, default_dtype, default_device
buildin_network_dir
import torchani.pyanitools
path = os.path.dirname(os.path.realpath(__file__))
conv_au_ev = 27.21138505
class NeuroChem (torchani.aev_base.AEVComputer):
class NeuroChem (torchani.aev.AEVComputer):
def __init__(self, dtype=default_dtype, device=default_device,
const_file=buildin_const_file, sae_file=buildin_sae_file,
def __init__(self, const_file=buildin_const_file,
sae_file=buildin_sae_file,
network_dir=buildin_network_dir):
super(NeuroChem, self).__init__(False, dtype, device, const_file)
super(NeuroChem, self).__init__(False, const_file)
self.sae_file = sae_file
self.network_dir = network_dir
self.nc = pyNeuroChem.molecule(
self.const_file, self.sae_file, self.network_dir, 0)
def _get_radial_part(self, fullaev):
radial_size = self.radial_length
return fullaev[:, :, :radial_size]
return fullaev[:, :, :self.radial_length]
def _get_angular_part(self, fullaev):
radial_size = self.radial_length
return fullaev[:, :, radial_size:]
return fullaev[:, :, self.radial_length:]
def _per_conformation(self, coordinates, species):
atoms = coordinates.shape[0]
......@@ -50,29 +48,27 @@ class NeuroChem (torchani.aev_base.AEVComputer):
coordinates[i], species) for i in range(conformations)]
aevs, energies, forces = zip(*results)
aevs = torch.from_numpy(numpy.stack(aevs)).type(
self.dtype).to(self.device)
self.EtaR.dtype).to(self.EtaR.device)
energies = torch.from_numpy(numpy.stack(energies)).type(
self.dtype).to(self.device)
self.EtaR.dtype).to(self.EtaR.device)
forces = torch.from_numpy(numpy.stack(forces)).type(
self.dtype).to(self.device)
self.EtaR.dtype).to(self.EtaR.device)
return self._get_radial_part(aevs), \
self._get_angular_part(aevs), \
energies, forces
aev = torchani.SortedAEV(device=torch.device('cpu'))
ncaev = NeuroChem(device=torch.device('cpu'))
ncaev = NeuroChem().to(torch.device('cpu'))
mol_count = 0
for i in [1, 2, 3, 4]:
data_file = os.path.join(
path, '../tests/dataset/ani_gdb_s0{}.h5'.format(i))
path, '../dataset/ani_gdb_s0{}.h5'.format(i))
adl = torchani.pyanitools.anidataloader(data_file)
for data in adl:
coordinates = data['coordinates'][:10, :]
coordinates = torch.from_numpy(coordinates).type(aev.dtype)
coordinates = torch.from_numpy(coordinates).type(ncaev.EtaR.dtype)
species = data['species']
coordinates, species = aev.sort_by_species(coordinates, species)
smiles = ''.join(data['smiles'])
radial, angular, energies, forces = ncaev(coordinates, species)
pickleobj = (coordinates, species, radial, angular, energies, forces)
......
......@@ -151,7 +151,7 @@ class PrepareInput(torch.nn.Module):
new_tensors = []
for t in tensors:
new_tensors.append(t.index_select(1, reverse))
return (species, *tensors)
return (species, *new_tensors)
def forward(self, species_coordinates):
species, coordinates = species_coordinates
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment