Commit 7cdd405c authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Don't hard code ensemble size (#352)

* Don't hard code ensemble size

* flake8

* fix
parent b9422498
...@@ -55,43 +55,21 @@ class ANIModel(torch.nn.Module): ...@@ -55,43 +55,21 @@ class ANIModel(torch.nn.Module):
class Ensemble(torch.nn.Module): class Ensemble(torch.nn.Module):
"""Compute the average output of an ensemble of modules.""" """Compute the average output of an ensemble of modules."""
# FIXME: due to PyTorch bug, we have to hard code the
# ensemble size to 8.
# def __init__(self, modules):
# super(Ensemble, self).__init__()
# self.modules_list = torch.nn.ModuleList(modules)
# def forward(self, species_input):
# # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
# outputs = [x(species_input)[1] for x in self.modules_list]
# species, _ = species_input
# return species, sum(outputs) / len(outputs)
def __init__(self, modules): def __init__(self, modules):
super(Ensemble, self).__init__() super(Ensemble, self).__init__()
assert len(modules) == 8 self.modules_list = torch.nn.ModuleList(modules)
self.model0 = modules[0] self.size = len(self.modules_list)
self.model1 = modules[1]
self.model2 = modules[2]
self.model3 = modules[3]
self.model4 = modules[4]
self.model5 = modules[5]
self.model6 = modules[6]
self.model7 = modules[7]
def __getitem__(self, i):
return [self.model0, self.model1, self.model2, self.model3,
self.model4, self.model5, self.model6, self.model7][i]
def forward(self, species_input): def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
sum_ = 0
for x in self.modules_list:
sum_ += x(species_input)[1]
species, _ = species_input species, _ = species_input
sum_ = self.model0(species_input)[1] + self.model1(species_input)[1] \ return species, sum_ / self.size
+ self.model2(species_input)[1] + self.model3(species_input)[1] \
+ self.model4(species_input)[1] + self.model5(species_input)[1] \ def __getitem__(self, i):
+ self.model6(species_input)[1] + self.model7(species_input)[1] return self.modules_list[i]
return species, sum_ / 8.0
class Sequential(torch.nn.Module): class Sequential(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment