Unverified Commit eb090700 authored by Gao, Xiang, committed by GitHub

simplify handling of species to avoid unnecessary arguments (#74)

parent ce21d224
@@ -3,6 +3,10 @@ import torchani
 import os
 
+consts = torchani.buildins.consts
+aev_computer = torchani.buildins.aev_computer
+
+
 def atomic():
     model = torch.nn.Sequential(
         torch.nn.Linear(384, 128),
@@ -17,13 +21,7 @@ def atomic():
 
 def get_or_create_model(filename, device=torch.device('cpu')):
-    aev_computer = torchani.buildins.aev_computer
-    model = torchani.ANIModel([
-        ('C', atomic()),
-        ('H', atomic()),
-        ('N', atomic()),
-        ('O', atomic()),
-    ])
+    model = torchani.ANIModel([atomic() for _ in range(4)])
 
 class Flatten(torch.nn.Module):
...
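The constructor change above drops the atomic symbols entirely; species order is now positional. A minimal before/after sketch, assuming the four-species setup of this example and its `atomic` factory:

    # before: ANIModel was built from (symbol, module) pairs:
    #     torchani.ANIModel([('C', atomic()), ('H', atomic()),
    #                        ('N', atomic()), ('O', atomic())])
    # after: a plain sequence of modules, one per species, in the order given
    # by consts.species; the symbols now live only in the Constants object
    model = torchani.ANIModel([atomic() for _ in range(consts.num_species)])

(`consts.num_species` is added by this commit in the Constants hunk below; it equals 4 for a four-species model.)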
@@ -54,7 +54,7 @@ start = timeit.default_timer()
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
 shift_energy = torchani.buildins.energy_shifter
 training, validation, testing = torchani.training.load_or_create(
-    parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
+    parser.dataset_checkpoint, parser.batch_size, model.consts.species,
     parser.dataset_path, device=device,
     transform=[shift_energy.subtract_from_dataset])
 container = torchani.training.Container({'energies': nnp})
...
@@ -24,7 +24,8 @@ device = torch.device(parser.device)
 nnp = model.get_or_create_model('/tmp/model.pt', device=device)
 shift_energy = torchani.buildins.energy_shifter
 dataset = torchani.training.BatchedANIDataset(
-    parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
+    parser.dataset_path, model.consts.species,
+    parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
 container = torchani.training.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
...
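Both training scripts make the same substitution: with ANIModel now a bare container of atomic networks, `nnp[0]` is a network rather than an object carrying a `species` attribute, so the symbol list has to come from the constants. A sketch of the new lookup, assuming the example module above is imported as `model`:

    import model  # the example module patched above

    species = model.consts.species        # e.g. ['H', 'C', 'N', 'O']
    assert len(species) == model.consts.num_species
    # nnp[0] is now just an atomic network; it no longer has .species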
@@ -42,11 +42,12 @@ class AEVComputer(torch.nn.Module):
         The name of the file that stores constants.
     Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
         Tensors storing constants.
-    species : list(str)
-        Chemical symbols of supported atom types
+    num_species : int
+        Number of supported atom types
     """
 
-    def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, species):
+    def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
+                 num_species):
         super(AEVComputer, self).__init__()
         self.register_buffer('Rcr', Rcr)
         self.register_buffer('Rca', Rca)
@@ -60,7 +61,7 @@ class AEVComputer(torch.nn.Module):
         self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
         self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
-        self.species = species
+        self.num_species = num_species
 
     def radial_sublength(self):
         """Returns the length of the radial subaev of a single species"""
@@ -68,7 +69,7 @@ class AEVComputer(torch.nn.Module):
     def radial_length(self):
         """Returns the length of the full radial aev"""
-        return len(self.species) * self.radial_sublength()
+        return self.num_species * self.radial_sublength()
 
     def angular_sublength(self):
         """Returns the length of the angular subaev of a single species"""
@@ -77,8 +78,8 @@ class AEVComputer(torch.nn.Module):
     def angular_length(self):
         """Returns the length of the full angular aev"""
-        species = len(self.species)
-        return int((species * (species + 1)) / 2) * self.angular_sublength()
+        s = self.num_species
+        return (s * (s + 1)) // 2 * self.angular_sublength()
 
     def aev_length(self):
         """Returns the length of the full aev"""
@@ -266,7 +267,7 @@ class AEVComputer(torch.nn.Module):
         """Tensor of shape (conformations, atoms, neighbors) storing species
         of neighbors."""
         mask_r = (species_r.unsqueeze(-1) ==
-                  torch.arange(len(self.species), device=self.EtaR.device))
+                  torch.arange(self.num_species, device=self.EtaR.device))
         return mask_r
 
     def compute_mask_a(self, species, indices_a, present_species):
@@ -348,7 +349,7 @@ class AEVComputer(torch.nn.Module):
             conformations, atoms, self.angular_sublength(),
             dtype=self.EtaR.dtype, device=self.EtaR.device)
         for s1, s2 in itertools.combinations_with_replacement(
-                range(len(self.species)), 2):
+                range(self.num_species), 2):
             if s1 in rev_indices and s2 in rev_indices:
                 i1 = rev_indices[s1]
                 i2 = rev_indices[s2]
...
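Since `Constants` is a `Mapping` whose keys (after the hunk further down) are the eight tensors plus `num_species`, an `AEVComputer` can presumably still be built by unpacking it. A sketch under that assumption, with `AEVComputer` taken to be reachable from the package top level:

    import torchani

    consts = torchani.buildins.consts      # the built-in constants object
    aev = torchani.AEVComputer(**consts)   # assumes a top-level export
    full = aev.aev_length()                # presumably radial_length() + angular_length()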
@@ -2,15 +2,14 @@ import torch
 from . import utils
 
-class ANIModel(torch.nn.Module):
+class ANIModel(torch.nn.ModuleList):
 
-    def __init__(self, models, reducer=torch.sum, padding_fill=0):
+    def __init__(self, modules, reducer=torch.sum, padding_fill=0):
         """
         Parameters
         ----------
-        models : (str, torch.nn.Module)
-            Models for all species. This must be a mapping where the key is
-            the atomic symbol and the value is a module.
+        modules : seq(torch.nn.Module)
+            Modules for all species.
         reducer : function
             Function of (input, dim)->output that reduces the input tensor along
             the given dimension to get an output tensor. This function will be
@@ -20,12 +19,9 @@ class ANIModel(torch.nn.Module):
         padding_fill : float
             Default value used to fill padding atoms
         """
-        super(ANIModel, self).__init__()
-        self.species = [s for s, _ in models]
+        super(ANIModel, self).__init__(modules)
         self.reducer = reducer
         self.padding_fill = padding_fill
-        for s, m in models:
-            setattr(self, 'model_' + s, m)
 
     def forward(self, species_aev):
         """Compute output from aev
@@ -56,11 +52,9 @@ class ANIModel(torch.nn.Module):
             output = torch.full_like(species_, self.padding_fill,
                                      dtype=aev.dtype)
             for i in present_species:
-                s = self.species[i]
-                model_X = getattr(self, 'model_' + s)
                 mask = (species_ == i)
                 input = aev.index_select(0, mask.nonzero().squeeze())
-                output[mask] = model_X(input).squeeze()
+                output[mask] = self[i](input).squeeze()
             output = output.view_as(species)
             return species, self.reducer(output, dim=1)
...
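Subclassing `torch.nn.ModuleList` is what makes the rest of the commit work: the per-species network is reached by integer index instead of a `model_` attribute, and submodule registration comes for free. A short usage sketch, reusing the example's `atomic` factory:

    import torch
    import torchani

    nnp = torchani.ANIModel([atomic() for _ in range(4)])

    net0 = nnp[0]  # replaces the old getattr(nnp, 'model_' + symbol) lookup
    # ModuleList registers each submodule, so the optimizer still sees all weights:
    optimizer = torch.optim.Adam(nnp.parameters())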
@@ -33,6 +33,7 @@ class Constants(Mapping):
                     self.species = value
             except Exception:
                 raise ValueError('unable to parse const file')
+        self.num_species = len(self.species)
         self.rev_species = {}
         for i in range(len(self.species)):
             s = self.species[i]
@@ -47,7 +48,7 @@ class Constants(Mapping):
         yield 'Zeta'
         yield 'ShfA'
         yield 'ShfZ'
-        yield 'species'
+        yield 'num_species'
 
     def __len__(self):
         return 8
@@ -232,7 +233,7 @@ def load_model(species, from_):
     models = []
     for i in species:
         filename = os.path.join(from_, 'ANN-{}.nnf'.format(i))
-        models.append((i, load_atomic_network(filename)))
+        models.append(load_atomic_network(filename))
     return ANIModel(models)
...
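`load_model` keeps its `species` argument, so the caller's symbol order is what fixes the index-to-species correspondence of the returned `ANIModel`. A sketch with a hypothetical NeuroChem network directory:

    # '/path/to/networks' is a placeholder for a NeuroChem network directory
    model = load_model(['H', 'C', 'N', 'O'], '/path/to/networks')
    # model[0] comes from ANN-H.nnf, model[1] from ANN-C.nnf, and so on; the
    # symbol order passed in should presumably match the .const file's order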