Unverified Commit d7ef8182 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove energy shifter's dependency on atomic symbols (#75)

parent eb090700
...@@ -10,11 +10,10 @@ network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') ...@@ -10,11 +10,10 @@ network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train')
ensemble = 8 ensemble = 8
consts = torchani.neurochem.Constants(const_file) consts = torchani.neurochem.Constants(const_file)
sae = torchani.neurochem.load_sae(sae_file)
aev_computer = torchani.AEVComputer(**consts) aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model_ensemble(consts.species, network_dir, nn = torchani.neurochem.load_model_ensemble(consts.species, network_dir,
ensemble) ensemble)
shift_energy = torchani.EnergyShifter(consts.species, sae) shift_energy = torchani.neurochem.load_sae(sae_file)
model = torch.nn.Sequential(aev_computer, nn, shift_energy) model = torch.nn.Sequential(aev_computer, nn, shift_energy)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
......
...@@ -31,7 +31,7 @@ parser = parser.parse_args() ...@@ -31,7 +31,7 @@ parser = parser.parse_args()
# load modules and datasets # load modules and datasets
device = torch.device(parser.device) device = torch.device(parser.device)
consts = torchani.neurochem.Constants(parser.const_file) consts = torchani.neurochem.Constants(parser.const_file)
sae = torchani.neurochem.load_sae(parser.sae_file) shift_energy = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts) aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, parser.network_dir) nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
model = torch.nn.Sequential(aev_computer, nn) model = torch.nn.Sequential(aev_computer, nn)
...@@ -39,7 +39,6 @@ container = torchani.training.Container({'energies': model}) ...@@ -39,7 +39,6 @@ container = torchani.training.Container({'energies': model})
container = container.to(device) container = container.to(device)
# load datasets # load datasets
shift_energy = torchani.EnergyShifter(consts.species, sae)
if parser.dataset_path.endswith('.h5') or \ if parser.dataset_path.endswith('.h5') or \
parser.dataset_path.endswith('.hdf5') or \ parser.dataset_path.endswith('.hdf5') or \
os.path.isdir(parser.dataset_path): os.path.isdir(parser.dataset_path):
......
...@@ -63,17 +63,15 @@ class Constants(Mapping): ...@@ -63,17 +63,15 @@ class Constants(Mapping):
def load_sae(filename):
    """Load atomic self energies from a NeuroChem ``.sae`` file.

    Each meaningful line has the form ``symbol,index = value`` (e.g.
    ``H,0 = -0.600953``).  Lines that do not parse — blank lines,
    comments, trailing junk — are skipped instead of crashing, matching
    the tolerant behavior of the previous parser.

    Arguments:
        filename (str): path to the ``.sae`` file.

    Returns:
        EnergyShifter: shifter whose self-energy list is ordered by the
        species index given in the file (not by file line order).
    """
    indexed_energies = []
    with open(filename) as f:
        for raw_line in f:
            try:
                lhs, rhs = [x.strip() for x in raw_line.split('=')]
                # species index is the field after the comma on the LHS
                index = int(lhs.split(',')[1].strip())
                value = float(rhs)
            except (ValueError, IndexError):
                continue  # ignore unrecognizable line
            indexed_energies.append((index, value))
    # order energies by species index so position == index
    self_energies = [e for _, e in sorted(indexed_energies)]
    return EnergyShifter(self_energies)
def load_atomic_network(filename): def load_atomic_network(filename):
...@@ -256,8 +254,7 @@ class Buildins: ...@@ -256,8 +254,7 @@ class Buildins:
self.sae_file = pkg_resources.resource_filename( self.sae_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat') __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
self.energy_shifter = EnergyShifter(self.consts.species, self.energy_shifter = load_sae(self.sae_file)
load_sae(self.sae_file))
self.ensemble_size = 8 self.ensemble_size = 8
self.ensemble_prefix = pkg_resources.resource_filename( self.ensemble_prefix = pkg_resources.resource_filename(
......
...@@ -90,7 +90,6 @@ class BatchedANIDataset(Dataset): ...@@ -90,7 +90,6 @@ class BatchedANIDataset(Dataset):
self.properties = properties self.properties = properties
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
device = torch.device('cpu')
# get name of files storing data # get name of files storing data
files = [] files = []
...@@ -111,14 +110,11 @@ class BatchedANIDataset(Dataset): ...@@ -111,14 +110,11 @@ class BatchedANIDataset(Dataset):
for m in anidataloader(f): for m in anidataloader(f):
species = m['species'] species = m['species']
indices = [self.species_indices[i] for i in species] indices = [self.species_indices[i] for i in species]
species = torch.tensor(indices, dtype=torch.long, species = torch.tensor(indices, dtype=torch.long)
device=device) coordinates = torch.from_numpy(m['coordinates'])
coordinates = torch.from_numpy(m['coordinates']) \
.type(dtype).to(device)
species_coordinates.append((species, coordinates)) species_coordinates.append((species, coordinates))
for i in properties: for i in properties:
properties[i].append(torch.from_numpy(m[i]) properties[i].append(torch.from_numpy(m[i]))
.type(dtype).to(device))
species, coordinates = utils.pad_and_batch(species_coordinates) species, coordinates = utils.pad_and_batch(species_coordinates)
for i in properties: for i in properties:
properties[i] = torch.cat(properties[i]) properties[i] = torch.cat(properties[i])
...@@ -126,7 +122,7 @@ class BatchedANIDataset(Dataset): ...@@ -126,7 +122,7 @@ class BatchedANIDataset(Dataset):
# shuffle if required # shuffle if required
conformations = coordinates.shape[0] conformations = coordinates.shape[0]
if shuffle: if shuffle:
indices = torch.randperm(conformations, device=device) indices = torch.randperm(conformations)
species = species.index_select(0, indices) species = species.index_select(0, indices)
coordinates = coordinates.index_select(0, indices) coordinates = coordinates.index_select(0, indices)
for i in properties: for i in properties:
...@@ -137,6 +133,11 @@ class BatchedANIDataset(Dataset): ...@@ -137,6 +133,11 @@ class BatchedANIDataset(Dataset):
species, coordinates, properties = t(species, coordinates, species, coordinates, properties = t(species, coordinates,
properties) properties)
# convert to desired dtype
species = species
coordinates = coordinates.to(dtype)
properties = {k: properties[k].to(dtype) for k in properties}
# split into minibatches, and strip reduncant padding # split into minibatches, and strip reduncant padding
natoms = (species >= 0).to(torch.long).sum(1) natoms = (species >= 0).to(torch.long).sum(1)
batches = [] batches = []
......
...@@ -38,22 +38,21 @@ def strip_redundant_padding(species, coordinates): ...@@ -38,22 +38,21 @@ def strip_redundant_padding(species, coordinates):
class EnergyShifter(torch.nn.Module): class EnergyShifter(torch.nn.Module):
def __init__(self, species, self_energies): def __init__(self, self_energies):
super(EnergyShifter, self).__init__() super(EnergyShifter, self).__init__()
self_energies_tensor = [self_energies[s] for s in species] self_energies = torch.tensor(self_energies, dtype=torch.double)
self.register_buffer('self_energies_tensor', self.register_buffer('self_energies', self_energies)
torch.tensor(self_energies_tensor,
dtype=torch.double))
def sae(self, species): def sae(self, species):
self_energies = self.self_energies_tensor[species] self_energies = self.self_energies[species]
self_energies[species == -1] = 0 self_energies[species == -1] = 0
return self_energies.sum(dim=1) return self_energies.sum(dim=1)
def subtract_from_dataset(self, species, coordinates, properties): def subtract_from_dataset(self, species, coordinates, properties):
dtype = properties['energies'].dtype energies = properties['energies']
device = properties['energies'].device device = energies.device
properties['energies'] -= self.sae(species).to(dtype).to(device) energies = energies.to(torch.double) - self.sae(species).to(device)
properties['energies'] = energies
return species, coordinates, properties return species, coordinates, properties
def forward(self, species_energies): def forward(self, species_energies):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment