Unverified Commit f50cc0b4 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

remove output_length (#54)

parent 73e447f0
......@@ -13,7 +13,6 @@ def atomic():
torch.nn.CELU(0.1),
torch.nn.Linear(64, 1)
)
model.output_length = 1
return model
......
......@@ -401,13 +401,9 @@ class SortedAEV(AEVComputer):
Tensor of shape (conformations, atoms, pairs, present species,
present species) storing the mask for each pair.
"""
species_a = self.combinations(species_a, -1)
species_a1, species_a2 = species_a
mask_a1 = (species_a1.unsqueeze(-1) ==
present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1)
== present_species)
species_a1, species_a2 = self.combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 * mask_a2
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = (mask + mask_rev) > 0
......
......@@ -52,7 +52,7 @@ class DictMetric(Metric):
def MSELoss(key, per_atom=True):
if per_atom:
return _PerAtomDictLoss(key, torch.nn.MSELoss(reduce=False))
return _PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
else:
return DictLoss(key, torch.nn.MSELoss())
......
......@@ -10,8 +10,6 @@ class ANIModel(BenchmarkedModule):
----------
species : list
Chemical symbol of supported atom species.
output_length : int
The length of output vector.
suffixes : sequence
Different suffixes denote different models in an ensemble.
model_<X><suffix> : nn.Module
......@@ -30,13 +28,12 @@ class ANIModel(BenchmarkedModule):
forward : total time for the forward pass
"""
def __init__(self, species, suffixes, reducer, output_length, models,
def __init__(self, species, suffixes, reducer, models,
benchmark=False):
super(ANIModel, self).__init__(benchmark)
self.species = species
self.suffixes = suffixes
self.reducer = reducer
self.output_length = output_length
for i in models:
setattr(self, i, models[i])
......@@ -72,6 +69,7 @@ class ANIModel(BenchmarkedModule):
for s in species_dedup:
begin = species.index(s)
end = atoms - rev_species.index(s)
part_atoms = end - begin
y = aev[:, begin:end, :].flatten(0, 1)
def apply_model(suffix):
......@@ -80,7 +78,7 @@ class ANIModel(BenchmarkedModule):
return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys)
y = y.view(conformations, -1, self.output_length)
y = y.view(conformations, part_atoms, -1)
per_species_outputs.append(y)
per_species_outputs = torch.cat(per_species_outputs, dim=1)
......
......@@ -17,22 +17,8 @@ class CustomModel(ANIModel):
The desired `reducer` attribute.
"""
suffixes = ['']
output_length = None
models = {}
for i in per_species:
model_X = per_species[i]
if not hasattr(model_X, 'output_length'):
raise ValueError(
'''atomic neural network must explicitly specify
output length''')
elif output_length is None:
output_length = model_X.output_length
elif output_length != model_X.output_length:
raise ValueError(
'''output length of each atomic neural network must
match''')
models['model_' + i] = per_species[i]
super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
reducer, output_length, models,
benchmark)
for i in per_species:
setattr(self, 'model_' + i, per_species[i])
reducer, models, benchmark)
......@@ -14,8 +14,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
----------
layers : int
Number of layers.
output_length : int
The length of output vector
layerN : torch.nn.Linear
Linear model for each layer.
activation : function
......@@ -202,7 +200,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
raise ValueError('bad parameter shape')
wfn = os.path.join(dirname, wfn)
bfn = os.path.join(dirname, bfn)
self.output_length = out_size
self._load_param_file(linear, in_size, out_size, wfn, bfn)
def _load_param_file(self, linear, in_size, out_size, wfn, bfn):
......
......@@ -43,18 +43,11 @@ class NeuroChemNNP(ANIModel):
reducer = torch.sum
models = {}
output_length = None
for network_dir, suffix in zip(network_dirs, suffixes):
for i in species:
filename = os.path.join(
network_dir, 'ANN-{}.nnf'.format(i))
model_X = NeuroChemAtomicNetwork(filename)
if output_length is None:
output_length = model_X.output_length
elif output_length != model_X.output_length:
raise ValueError(
'''output length of each atomic neural networt
must match''')
models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
output_length, models, benchmark)
models, benchmark)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment