Unverified commit eb090700 authored by Gao, Xiang, committed by GitHub

simplify handling of species to avoid unnecessary arguments (#74)

parent ce21d224
@@ -3,6 +3,10 @@ import torchani
 import os
+consts = torchani.buildins.consts
+aev_computer = torchani.buildins.aev_computer
 def atomic():
     model = torch.nn.Sequential(
         torch.nn.Linear(384, 128),
@@ -17,13 +21,7 @@ def atomic():
 def get_or_create_model(filename, device=torch.device('cpu')):
-    aev_computer = torchani.buildins.aev_computer
-    model = torchani.ANIModel([
-        ('C', atomic()),
-        ('H', atomic()),
-        ('N', atomic()),
-        ('O', atomic()),
-    ])
+    model = torchani.ANIModel([atomic() for _ in range(4)])
 class Flatten(torch.nn.Module):
......
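Note on the new call site: the example now builds one network per supported species purely by position, so the chemical symbols never need to be threaded through the model. A minimal sketch of the new construction, assuming this commit's torchani (with the `buildins` module re-exported above) is importable; `atomic()` below is only a stand-in that mirrors the shape of the network in this file, with a placeholder activation:

```python
import torch
import torchani

def atomic():
    # Stand-in for the per-species network defined in this example file;
    # 384 is the AEV length for the built-in constants (see the aev.py hunks below).
    return torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 1),
    )

consts = torchani.buildins.consts
# One copy of the network per species; submodule i serves consts.species[i].
# consts.num_species is added to Constants later in this same commit.
model = torchani.ANIModel([atomic() for _ in range(consts.num_species)])
```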
@@ -54,7 +54,7 @@ start = timeit.default_timer()
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
 shift_energy = torchani.buildins.energy_shifter
 training, validation, testing = torchani.training.load_or_create(
-    parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
+    parser.dataset_checkpoint, parser.batch_size, model.consts.species,
     parser.dataset_path, device=device,
     transform=[shift_energy.subtract_from_dataset])
 container = torchani.training.Container({'energies': nnp})
......
@@ -24,7 +24,8 @@ device = torch.device(parser.device)
 nnp = model.get_or_create_model('/tmp/model.pt', device=device)
 shift_energy = torchani.buildins.energy_shifter
 dataset = torchani.training.BatchedANIDataset(
-    parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
+    parser.dataset_path, model.consts.species,
+    parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
 container = torchani.training.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
......
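Note on the two script hunks above: `nnp[0]` is now just the first per-species network (a plain `Sequential`), which has no `.species` attribute, so the scripts read the symbol ordering from the built-in constants instead, re-exported by the example module as `model.consts`. A small sketch of that lookup, assuming the same module layout as in these scripts:

```python
import torchani

# Same object the example module re-exports as model.consts.
consts = torchani.buildins.consts

# The dataset loaders only need the symbol ordering so they can map chemical
# symbols in the data files to the integer species indices used by the model.
species_order = consts.species        # contents come from the const file, e.g. ['H', 'C', 'N', 'O']
print(species_order, consts.num_species)
```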
@@ -42,11 +42,12 @@ class AEVComputer(torch.nn.Module):
         The name of the file that stores constant.
     Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
         Tensor storing constants.
-    species : list(str)
-        Chemical symbols of supported atom types
+    num_species : int
+        Number of supported atom types
     """
-    def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, species):
+    def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
+                 num_species):
         super(AEVComputer, self).__init__()
         self.register_buffer('Rcr', Rcr)
         self.register_buffer('Rca', Rca)
@@ -60,7 +61,7 @@ class AEVComputer(torch.nn.Module):
         self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
         self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
-        self.species = species
+        self.num_species = num_species
     def radial_sublength(self):
         """Returns the length of radial subaev of a single species"""
@@ -68,7 +69,7 @@ class AEVComputer(torch.nn.Module):
     def radial_length(self):
         """Returns the length of full radial aev"""
-        return len(self.species) * self.radial_sublength()
+        return self.num_species * self.radial_sublength()
     def angular_sublength(self):
         """Returns the length of angular subaev of a single species"""
@@ -77,8 +78,8 @@ class AEVComputer(torch.nn.Module):
     def angular_length(self):
         """Returns the length of full angular aev"""
-        species = len(self.species)
-        return int((species * (species + 1)) / 2) * self.angular_sublength()
+        s = self.num_species
+        return (s * (s + 1)) // 2 * self.angular_sublength()
     def aev_length(self):
         """Returns the length of full aev"""
@@ -266,7 +267,7 @@ class AEVComputer(torch.nn.Module):
         """Tensor of shape (conformations, atoms, neighbors) storing species
         of neighbors."""
         mask_r = (species_r.unsqueeze(-1) ==
-                  torch.arange(len(self.species), device=self.EtaR.device))
+                  torch.arange(self.num_species, device=self.EtaR.device))
         return mask_r
     def compute_mask_a(self, species, indices_a, present_species):
@@ -348,7 +349,7 @@ class AEVComputer(torch.nn.Module):
             conformations, atoms, self.angular_sublength(),
             dtype=self.EtaR.dtype, device=self.EtaR.device)
         for s1, s2 in itertools.combinations_with_replacement(
-                range(len(self.species)), 2):
+                range(self.num_species), 2):
             if s1 in rev_indices and s2 in rev_indices:
                 i1 = rev_indices[s1]
                 i2 = rev_indices[s2]
......
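Inside `AEVComputer`, `len(self.species)` only ever fed `torch.arange` and `range`, which is why a plain count is enough. A self-contained sketch of the two patterns touched in the hunks above, using only `torch` and `itertools`; the tensor shapes are illustrative, not the real AEV pipeline:

```python
import itertools
import torch

num_species = 4

# Stand-in for species_r in compute_mask_r: a (conformations, atoms, neighbors)
# tensor of neighbor species indices.
species_r = torch.randint(0, num_species, (2, 5, 3))

# Broadcasting against arange yields a per-species boolean mask of shape
# (conformations, atoms, neighbors, num_species).
mask_r = species_r.unsqueeze(-1) == torch.arange(num_species)
print(mask_r.shape)

# The angular terms iterate over unordered species pairs, including s1 == s2.
pairs = list(itertools.combinations_with_replacement(range(num_species), 2))
print(len(pairs), pairs)   # 10 pairs for 4 species
```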
@@ -2,15 +2,14 @@ import torch
 from . import utils
-class ANIModel(torch.nn.Module):
+class ANIModel(torch.nn.ModuleList):
-    def __init__(self, models, reducer=torch.sum, padding_fill=0):
+    def __init__(self, modules, reducer=torch.sum, padding_fill=0):
         """
         Parameters
         ----------
-        models : (str, torch.nn.Module)
-            Models for all species. This must be a mapping where the key is
-            atomic symbol and the value is a module.
+        modules : seq(torch.nn.Module)
+            Modules for all species.
         reducer : function
             Function of (input, dim)->output that reduce the input tensor along
             the given dimension to get an output tensor. This function will be
@@ -20,12 +19,9 @@ class ANIModel(torch.nn.Module):
         padding_fill : float
             Default value used to fill padding atoms
         """
-        super(ANIModel, self).__init__()
-        self.species = [s for s, _ in models]
+        super(ANIModel, self).__init__(modules)
         self.reducer = reducer
         self.padding_fill = padding_fill
-        for s, m in models:
-            setattr(self, 'model_' + s, m)
     def forward(self, species_aev):
         """Compute output from aev
@@ -56,11 +52,9 @@ class ANIModel(torch.nn.Module):
         output = torch.full_like(species_, self.padding_fill,
                                  dtype=aev.dtype)
         for i in present_species:
-            s = self.species[i]
-            model_X = getattr(self, 'model_' + s)
             mask = (species_ == i)
             input = aev.index_select(0, mask.nonzero().squeeze())
-            output[mask] = model_X(input).squeeze()
+            output[mask] = self[i](input).squeeze()
         output = output.view_as(species)
         return species, self.reducer(output, dim=1)
......
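Because `ANIModel` is now a `ModuleList`, the per-species network for species index `i` is simply `self[i]`, and the symbol-to-attribute lookup (`getattr(self, 'model_' + s)`) disappears. A rough, self-contained sketch of that forward pattern with flattened shapes and toy networks; this is an illustration of the indexing idea, not the exact torchani implementation:

```python
import torch

aev_dim, num_species = 8, 4
nets = torch.nn.ModuleList(torch.nn.Linear(aev_dim, 1) for _ in range(num_species))

species_ = torch.tensor([0, 2, 2, 1, 0])          # flattened species index per atom
aev = torch.randn(species_.shape[0], aev_dim)     # flattened AEV per atom

output = torch.zeros(species_.shape[0])           # analogue of the padding_fill buffer
for i in species_.unique().tolist():
    mask = species_ == i
    input = aev.index_select(0, mask.nonzero().squeeze(-1))
    output[mask] = nets[i](input).squeeze(-1)     # "self[i](input)" in the real ANIModel
print(output)
```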
@@ -33,6 +33,7 @@ class Constants(Mapping):
                         self.species = value
                 except Exception:
                     raise ValueError('unable to parse const file')
+        self.num_species = len(self.species)
         self.rev_species = {}
         for i in range(len(self.species)):
             s = self.species[i]
@@ -47,7 +48,7 @@ class Constants(Mapping):
         yield 'Zeta'
         yield 'ShfA'
         yield 'ShfZ'
-        yield 'species'
+        yield 'num_species'
     def __len__(self):
         return 8
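`Constants` is a `Mapping` presumably so the parsed constants can be `**`-unpacked straight into `AEVComputer`; switching the yielded key from `'species'` to `'num_species'` keeps that unpacking aligned with the new constructor signature. A toy illustration of the pattern with a stand-in mapping (not the real const-file parser, and `toy_aev_computer` is a hypothetical receiver):

```python
from collections.abc import Mapping

class ToyConsts(Mapping):
    """Minimal stand-in: yields keyword names that match a constructor."""
    def __init__(self):
        self._data = {'Rcr': 5.2, 'Rca': 3.5, 'num_species': 4}
    def __getitem__(self, key):
        return self._data[key]
    def __iter__(self):
        return iter(self._data)
    def __len__(self):
        return len(self._data)

def toy_aev_computer(Rcr, Rca, num_species):
    return (Rcr, Rca, num_species)

consts = ToyConsts()
print(toy_aev_computer(**consts))   # mirrors unpacking Constants into AEVComputer
```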
@@ -232,7 +233,7 @@ def load_model(species, from_):
     models = []
     for i in species:
         filename = os.path.join(from_, 'ANN-{}.nnf'.format(i))
-        models.append((i, load_atomic_network(filename)))
+        models.append(load_atomic_network(filename))
     return ANIModel(models)
......
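With the `(symbol, module)` pairs gone, the list handed to `ANIModel` is ordered exactly like the `species` argument of `load_model`, so the species index is the only link between an atom type and its network. A small sketch of that ordering contract; the directory and the stand-in networks below are placeholders, only the `'ANN-{}.nnf'` naming comes from the hunk above:

```python
import os
import torch

species = ['H', 'C', 'N', 'O']   # whatever order the caller passes to load_model
nets = []
for i, s in enumerate(species):
    filename = os.path.join('/path/to/networks', 'ANN-{}.nnf'.format(s))
    # load_atomic_network(filename) would parse the .nnf file; use a stand-in here.
    nets.append(torch.nn.Linear(384, 1))
    print(i, s, filename)

model = torch.nn.ModuleList(nets)   # ANIModel(nets) in torchani; index i serves species[i]
```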