"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "83427953f281c315bdd357ebcf94eced6940475c"
Unverified commit f50cc0b4, authored by Gao, Xiang, committed by GitHub

remove output_length (#54)

parent 73e447f0
@@ -13,7 +13,6 @@ def atomic():
         torch.nn.CELU(0.1),
         torch.nn.Linear(64, 1)
     )
-    model.output_length = 1
     return model
...
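With the stored attribute gone, the output size of an atomic network is simply whatever its last layer produces. A minimal sketch along the lines of the atomic() helper above, with the input width (384) assumed only for illustration:

import torch

# Hypothetical atomic network: the output size (1) is carried by the final
# Linear layer itself instead of a separate output_length attribute.
model = torch.nn.Sequential(
    torch.nn.Linear(384, 64),   # input width assumed for illustration
    torch.nn.CELU(0.1),
    torch.nn.Linear(64, 1),
)

aev = torch.randn(10, 384)      # 10 atoms with assumed 384-dimensional AEVs
print(model(aev).shape)         # torch.Size([10, 1])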
@@ -401,13 +401,9 @@ class SortedAEV(AEVComputer):
             Tensor of shape (conformations, atoms, pairs, present species,
             present species) storing the mask for each pair.
         """
-        species_a = self.combinations(species_a, -1)
-        species_a1, species_a2 = species_a
-        mask_a1 = (species_a1.unsqueeze(-1) ==
-                   present_species).unsqueeze(-1)
-        mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1)
-                   == present_species)
+        species_a1, species_a2 = self.combinations(species_a, -1)
+        mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
+        mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
         mask = mask_a1 * mask_a2
         mask_rev = mask.permute(0, 1, 2, 4, 3)
         mask_a = (mask + mask_rev) > 0
...
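The tightened mask construction relies on the same broadcasting idiom as before, just without the intermediate tuple. A standalone sketch with deliberately simplified shapes (the real tensors also carry an atoms dimension):

import torch

# Simplified shapes: one conformation, three atom pairs, two present species.
species_a1 = torch.tensor([[0, 1, 1]])        # (conformations, pairs)
present_species = torch.tensor([0, 1])        # (present species,)

# unsqueeze adds a trailing axis so == broadcasts against present_species,
# yielding one boolean entry per (pair, present species) combination.
mask_a1 = species_a1.unsqueeze(-1) == present_species   # shape (1, 3, 2)
print(mask_a1.shape)
print(mask_a1)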
@@ -52,7 +52,7 @@ class DictMetric(Metric):
 def MSELoss(key, per_atom=True):
     if per_atom:
-        return _PerAtomDictLoss(key, torch.nn.MSELoss(reduce=False))
+        return _PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
     else:
         return DictLoss(key, torch.nn.MSELoss())
...
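This hunk is an API migration rather than part of the output_length removal: torch.nn.MSELoss(reduce=False) is the deprecated spelling of reduction='none', which returns one squared error per element instead of a scalar, exactly what the per-atom loss wrapper needs. A small sketch of the behaviour:

import torch

# reduction='none' keeps per-element losses; 'mean' would average them.
loss_fn = torch.nn.MSELoss(reduction='none')

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])

per_element = loss_fn(pred, target)   # tensor([0.2500, 0.0000, 1.0000])
print(per_element)
print(per_element.mean())             # what reduction='mean' would return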
@@ -10,8 +10,6 @@ class ANIModel(BenchmarkedModule):
     ----------
     species : list
         Chemical symbol of supported atom species.
-    output_length : int
-        The length of output vector.
     suffixes : sequence
         Different suffixes denote different models in an ensemble.
     model_<X><suffix> : nn.Module
@@ -30,13 +28,12 @@ class ANIModel(BenchmarkedModule):
         forward : total time for the forward pass
     """

-    def __init__(self, species, suffixes, reducer, output_length, models,
+    def __init__(self, species, suffixes, reducer, models,
                  benchmark=False):
         super(ANIModel, self).__init__(benchmark)
         self.species = species
         self.suffixes = suffixes
         self.reducer = reducer
-        self.output_length = output_length
         for i in models:
             setattr(self, i, models[i])
@@ -72,6 +69,7 @@ class ANIModel(BenchmarkedModule):
         for s in species_dedup:
             begin = species.index(s)
             end = atoms - rev_species.index(s)
+            part_atoms = end - begin
             y = aev[:, begin:end, :].flatten(0, 1)

             def apply_model(suffix):
@@ -80,7 +78,7 @@ class ANIModel(BenchmarkedModule):
                 return model_X(y)
             ys = [apply_model(suffix) for suffix in self.suffixes]
             y = sum(ys) / len(ys)
-            y = y.view(conformations, -1, self.output_length)
+            y = y.view(conformations, part_atoms, -1)
             per_species_outputs.append(y)
         per_species_outputs = torch.cat(per_species_outputs, dim=1)
...
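The new part_atoms variable is what makes dropping output_length possible: once the number of atoms in the current species block is known, view() can infer the remaining per-atom output dimension from -1. A toy sketch of that reshape, with invented shapes:

import torch

# Toy shapes: 4 conformations, a species block of 3 atoms, per-atom output of 1.
conformations, part_atoms = 4, 3
flat = torch.randn(conformations * part_atoms, 1)   # flattened network output

# With part_atoms fixed, view() infers the last dimension from -1, so the
# model no longer needs a stored output_length attribute.
y = flat.view(conformations, part_atoms, -1)
print(y.shape)   # torch.Size([4, 3, 1])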
@@ -17,22 +17,8 @@ class CustomModel(ANIModel):
             The desired `reducer` attribute.
         """
         suffixes = ['']
-        output_length = None
         models = {}
         for i in per_species:
-            model_X = per_species[i]
-            if not hasattr(model_X, 'output_length'):
-                raise ValueError(
-                    '''atomic neural network must explicitly specify
-                    output length''')
-            elif output_length is None:
-                output_length = model_X.output_length
-            elif output_length != model_X.output_length:
-                raise ValueError(
-                    '''output length of each atomic neural network must
-                    match''')
+            models['model_' + i] = per_species[i]
         super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
-                                          reducer, output_length, models,
-                                          benchmark)
-        for i in per_species:
-            setattr(self, 'model_' + i, per_species[i])
+                                          reducer, models, benchmark)
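With the output-length consistency checks gone, CustomModel's constructor reduces to packing the per-species networks into the models dict that ANIModel expects, keyed as 'model_' + species symbol. A sketch of that mapping with hypothetical networks (input width 384 assumed only for illustration):

import torch

def make_net():
    # Hypothetical per-species network.
    return torch.nn.Sequential(torch.nn.Linear(384, 64),
                               torch.nn.CELU(0.1),
                               torch.nn.Linear(64, 1))

per_species = {'H': make_net(), 'C': make_net(), 'N': make_net(), 'O': make_net()}

# Same keying as the simplified loop above.
models = {'model_' + i: net for i, net in per_species.items()}
print(sorted(models))   # ['model_C', 'model_H', 'model_N', 'model_O']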
@@ -14,8 +14,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
     ----------
     layers : int
         Number of layers.
-    output_length : int
-        The length of output vector
     layerN : torch.nn.Linear
         Linear model for each layer.
     activation : function
@@ -202,7 +200,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
                 raise ValueError('bad parameter shape')
             wfn = os.path.join(dirname, wfn)
             bfn = os.path.join(dirname, bfn)
-            self.output_length = out_size
             self._load_param_file(linear, in_size, out_size, wfn, bfn)

     def _load_param_file(self, linear, in_size, out_size, wfn, bfn):
...
@@ -43,18 +43,11 @@ class NeuroChemNNP(ANIModel):
         reducer = torch.sum
         models = {}
-        output_length = None
         for network_dir, suffix in zip(network_dirs, suffixes):
             for i in species:
                 filename = os.path.join(
                     network_dir, 'ANN-{}.nnf'.format(i))
                 model_X = NeuroChemAtomicNetwork(filename)
-                if output_length is None:
-                    output_length = model_X.output_length
-                elif output_length != model_X.output_length:
-                    raise ValueError(
-                        '''output length of each atomic neural networt
-                        must match''')
                 models['model_' + i + suffix] = model_X
         super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
-                                           output_length, models, benchmark)
+                                           models, benchmark)
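The NeuroChem loader follows the same pattern across an ensemble, keying each member's per-species network with the directory's suffix. A sketch of just the key scheme, with hypothetical paths (the real loader builds a NeuroChemAtomicNetwork from each .nnf file):

import os

species = ['H', 'C']
network_dirs = ['train0/networks', 'train1/networks']   # assumed layout
suffixes = ['0', '1']

# Keys take the form 'model_' + species + suffix, so ensemble members
# for the same species stay distinguishable.
for network_dir, suffix in zip(network_dirs, suffixes):
    for i in species:
        filename = os.path.join(network_dir, 'ANN-{}.nnf'.format(i))
        print('model_' + i + suffix, '<-', filename)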