Commit b9422498 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Use enumerate to clean up ANIModel (#358)

parent 669f70a1
...@@ -41,10 +41,9 @@ class ANIModel(torch.nn.Module): ...@@ -41,10 +41,9 @@ class ANIModel(torch.nn.Module):
output = torch.full(species_.shape, self.padding_fill, output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype, device=species.device) dtype=aev.dtype, device=species.device)
i = 0
for m in self.module_list: for i, m in enumerate(self.module_list):
mask = (species_ == i) mask = (species_ == i)
i += 1
midx = mask.nonzero().flatten() midx = mask.nonzero().flatten()
if midx.shape[0] > 0: if midx.shape[0] > 0:
input_ = aev.index_select(0, midx) input_ = aev.index_select(0, midx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment