"src/vscode:/vscode.git/clone" did not exist on "12fd0736dcc51f77c52130ab10177d0c1d5a29d9"
Commit be58b3fe authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Cleanup some junks inside ANIModel introduced due to JIT (#362)

parent 86500df0
......@@ -16,19 +16,11 @@ class ANIModel(torch.nn.Module):
:attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`.
padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For
example, if the reducer is :func:`torch.sum`, then
:attr:`padding_fill` should be 0, and if the reducer is
:func:`torch.min`, then :attr:`padding_fill` should be
:obj:`math.inf`.
"""
def __init__(self, modules, padding_fill=0):
def __init__(self, modules):
super(ANIModel, self).__init__()
self.module_list = torch.nn.ModuleList(modules)
self.padding_fill = padding_fill
def __getitem__(self, i):
return self.module_list[i]
......@@ -39,8 +31,7 @@ class ANIModel(torch.nn.Module):
species_ = species.flatten()
aev = aev.flatten(0, 1)
output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype, device=species.device)
output = aev.new_zeros(species_.shape)
for i, m in enumerate(self.module_list):
mask = (species_ == i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment