Commit e41c2d93 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Subclass ModuleList to simplify code (#385)

parent 93372134
...@@ -8,7 +8,7 @@ class SpeciesEnergies(NamedTuple): ...@@ -8,7 +8,7 @@ class SpeciesEnergies(NamedTuple):
energies: Tensor energies: Tensor
class ANIModel(torch.nn.Module): class ANIModel(torch.nn.ModuleList):
"""ANI model that compute energies from species and AEVs. """ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing Different atom types might have different modules, when computing
...@@ -27,11 +27,7 @@ class ANIModel(torch.nn.Module): ...@@ -27,11 +27,7 @@ class ANIModel(torch.nn.Module):
""" """
def __init__(self, modules): def __init__(self, modules):
super(ANIModel, self).__init__() super(ANIModel, self).__init__(modules)
self.module_list = torch.nn.ModuleList(modules)
def __getitem__(self, i):
return self.module_list[i]
def forward(self, species_aev: Tuple[Tensor, Tensor], def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
...@@ -44,7 +40,7 @@ class ANIModel(torch.nn.Module): ...@@ -44,7 +40,7 @@ class ANIModel(torch.nn.Module):
output = aev.new_zeros(species_.shape) output = aev.new_zeros(species_.shape)
for i, m in enumerate(self.module_list): for i, m in enumerate(self):
mask = (species_ == i) mask = (species_ == i)
midx = mask.nonzero().flatten() midx = mask.nonzero().flatten()
if midx.shape[0] > 0: if midx.shape[0] > 0:
...@@ -54,13 +50,12 @@ class ANIModel(torch.nn.Module): ...@@ -54,13 +50,12 @@ class ANIModel(torch.nn.Module):
return SpeciesEnergies(species, torch.sum(output, dim=1)) return SpeciesEnergies(species, torch.sum(output, dim=1))
class Ensemble(torch.nn.Module): class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules.""" """Compute the average output of an ensemble of modules."""
def __init__(self, modules): def __init__(self, modules):
super(Ensemble, self).__init__() super(Ensemble, self).__init__(modules)
self.modules_list = torch.nn.ModuleList(modules) self.size = len(modules)
self.size = len(self.modules_list)
def forward(self, species_input: Tuple[Tensor, Tensor], def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
...@@ -68,26 +63,22 @@ class Ensemble(torch.nn.Module): ...@@ -68,26 +63,22 @@ class Ensemble(torch.nn.Module):
assert cell is None assert cell is None
assert pbc is None assert pbc is None
sum_ = 0 sum_ = 0
for x in self.modules_list: for x in self:
sum_ += x(species_input)[1] sum_ += x(species_input)[1]
species, _ = species_input species, _ = species_input
return SpeciesEnergies(species, sum_ / self.size) return SpeciesEnergies(species, sum_ / self.size)
def __getitem__(self, i):
return self.modules_list[i]
class Sequential(torch.nn.Module): class Sequential(torch.nn.ModuleList):
"""Modified Sequential module that accept Tuple type as input""" """Modified Sequential module that accept Tuple type as input"""
def __init__(self, *modules): def __init__(self, *modules):
super(Sequential, self).__init__() super(Sequential, self).__init__(modules)
self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_: Tuple[Tensor, Tensor], def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None): pbc: Optional[Tensor] = None):
for module in self.modules_list: for module in self:
input_ = module(input_, cell=cell, pbc=pbc) input_ = module(input_, cell=cell, pbc=pbc)
cell = None cell = None
pbc = None pbc = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment