"vscode:/vscode.git/clone" did not exist on "c2e61ce10b4704ce24133900731787c0bcc5847e"
Commit 7cdd405c authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Don't hard code ensemble size (#352)

* Don't hard code ensemble size

* flake8

* fix
parent b9422498
......@@ -55,43 +55,21 @@ class ANIModel(torch.nn.Module):
class Ensemble(torch.nn.Module):
"""Compute the average output of an ensemble of modules."""
# FIXME: due to PyTorch bug, we have to hard code the
# ensemble size to 8.
# def __init__(self, modules):
# super(Ensemble, self).__init__()
# self.modules_list = torch.nn.ModuleList(modules)
# def forward(self, species_input):
# # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
# outputs = [x(species_input)[1] for x in self.modules_list]
# species, _ = species_input
# return species, sum(outputs) / len(outputs)
def __init__(self, modules):
super(Ensemble, self).__init__()
assert len(modules) == 8
self.model0 = modules[0]
self.model1 = modules[1]
self.model2 = modules[2]
self.model3 = modules[3]
self.model4 = modules[4]
self.model5 = modules[5]
self.model6 = modules[6]
self.model7 = modules[7]
def __getitem__(self, i):
return [self.model0, self.model1, self.model2, self.model3,
self.model4, self.model5, self.model6, self.model7][i]
self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list)
def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
sum_ = 0
for x in self.modules_list:
sum_ += x(species_input)[1]
species, _ = species_input
sum_ = self.model0(species_input)[1] + self.model1(species_input)[1] \
+ self.model2(species_input)[1] + self.model3(species_input)[1] \
+ self.model4(species_input)[1] + self.model5(species_input)[1] \
+ self.model6(species_input)[1] + self.model7(species_input)[1]
return species, sum_ / 8.0
return species, sum_ / self.size
def __getitem__(self, i):
return self.modules_list[i]
class Sequential(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment