Commit b9422498 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Use enumerate to clean up ANIModel (#358)

parent 669f70a1
......@@ -41,10 +41,9 @@ class ANIModel(torch.nn.Module):
output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype, device=species.device)
i = 0
for m in self.module_list:
for i, m in enumerate(self.module_list):
mask = (species_ == i)
i += 1
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment