Unverified Commit c7f549f4 authored by Gao, Xiang, committed by GitHub

Improve according to codefactor.io (#130)

parent 9719edb1
# <img src=https://raw.githubusercontent.com/aiqm/torchani/master/logo1.png width=180/> Accurate Neural Network Potential on PyTorch
[![Codefresh build status](https://g.codefresh.io/api/badges/pipeline/zasdfgbnm/aiqm%2Ftorchani%2Ftorchani?branch=master&type=cf-1)](https://g.codefresh.io/repositories/aiqm/torchani/builds?filter=trigger:build;branch:master;service:5babc52a8a90dc40a407b05f~torchani)
[![CodeFactor](https://www.codefactor.io/repository/github/aiqm/torchani/badge/master)](https://www.codefactor.io/repository/github/aiqm/torchani/overview/master)
[![codecov](https://codecov.io/gh/aiqm/torchani/branch/master/graph/badge.svg)](https://codecov.io/gh/aiqm/torchani)
TorchANI is a PyTorch implementation of ANI. It is currently in alpha release, which means the API is not yet stable. If you find a bug in TorchANI or have a feature request, feel free to open an issue on GitHub or send us a pull request.
...
@@ -16,10 +16,10 @@ def default_neighborlist(species, coordinates, cutoff):
"""Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
-"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
+# vec has shape (conformations, atoms, atoms, 3) storing Rij vectors
distances = vec.norm(2, -1)
-"""Shape (conformations, atoms, atoms) storing Rij distances"""
+# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
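The hunk above only rewrites the shape notes as comments. For readers following along, here is a minimal standalone sketch of the same pairwise-distance step; the tensor shapes and the `-1` padding convention come from the diff, while the random inputs are invented for illustration and this is not the library's API:

```python
import math
import torch

# coordinates: (conformations, atoms, 3); species: (conformations, atoms),
# with -1 marking padding atoms (illustrative values only).
coordinates = torch.randn(2, 4, 3)
species = torch.tensor([[0, 1, 1, -1],
                        [0, 0, 1, 1]])

vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
# vec has shape (conformations, atoms, atoms, 3) storing Rij vectors
distances = vec.norm(2, -1)
# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
# padded atoms end up infinitely far away, so they never pass a cutoff test
```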
@@ -267,7 +267,8 @@ class AEVComputer(torch.nn.Module):
radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype)
).sum(-3)
-"""shape (conformations, atoms, present species, radial_length)"""
+# present_radial_aevs has shape
+# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
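As a quick check of the shape bookkeeping documented in this hunk, a small self-contained sketch (the dimension sizes are invented) of what `flatten(start_dim=2)` does to the per-species radial AEVs:

```python
import torch

# Invented sizes, for illustration only.
conformations, atoms, present_species, radial_length = 2, 5, 4, 16
present_radial_aevs = torch.zeros(conformations, atoms, present_species, radial_length)

radial_aevs = present_radial_aevs.flatten(start_dim=2)
print(radial_aevs.shape)  # torch.Size([2, 5, 64]): species and radial axes merged per atom
```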
...
@@ -38,7 +38,7 @@ def split_batch(natoms, species, coordinates):
natoms = natoms.tolist()
counts = []
for i in natoms:
-if len(counts) == 0:
+if not counts:
counts.append([i, 1])
continue
if i == counts[-1][0]:
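The loop in this hunk run-length encodes consecutive equal atom counts, and the change swaps `len(counts) == 0` for the truthiness test `if not counts:`. A hedged, self-contained sketch of the same pattern; the input list and the continuation of the loop beyond the truncated hunk are guessed for illustration:

```python
# Hypothetical input: per-conformation atom counts, assumed grouped into runs.
natoms = [3, 3, 5, 5, 5, 2]

counts = []
for i in natoms:
    if not counts:             # an empty list is falsy; replaces len(counts) == 0
        counts.append([i, 1])
        continue
    if i == counts[-1][0]:     # same atom count as the previous conformation
        counts[-1][1] += 1     # assumed continuation: extend the current run
    else:
        counts.append([i, 1])  # assumed continuation: start a new run

print(counts)  # [[3, 2], [5, 3], [2, 1]]
```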
...
@@ -91,7 +91,6 @@ def MSELoss(key, per_atom=True):
"""Create MSE loss on the specified key."""
if per_atom:
return PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
-else:
return DictLoss(key, torch.nn.MSELoss())
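This hunk drops an `else:` that is redundant because the preceding branch already returns. A minimal sketch of the pattern, using plain `torch.nn.MSELoss` objects as stand-ins since `DictLoss` and `PerAtomDictLoss` live in the diffed module:

```python
import torch

def make_mse_loss(per_atom=True):
    if per_atom:
        # the early return makes a following else: unnecessary
        return torch.nn.MSELoss(reduction='none')
    return torch.nn.MSELoss()
```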
...
@@ -526,10 +526,9 @@ class Trainer:
# There is no plan to support the "L2" settings in
# input file before AdamW get merged into pytorch.
raise NotImplementedError('L2 not supported yet')
-l2reg.append((0.5 * layer['l2valu'], module))
del layer['l2norm']
del layer['l2valu']
-if len(layer) > 0:
+if layer:
raise ValueError('unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
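The `if layer:` change relies on the same idea used throughout this commit: once every recognized option has been deleted from the layer dict, anything left over is unrecognized, and a non-empty dict is truthy. A small sketch with hypothetical keys (only `nodes` and `activation` are treated as recognized here):

```python
# Hypothetical layer options, not the trainer's real input format.
layer = {'nodes': 128, 'activation': 'gaussian'}

del layer['nodes']
del layer['activation']
if layer:  # truthiness replaces len(layer) > 0; any leftover key would raise
    raise ValueError('unrecognized parameter in layer setup')
```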
@@ -549,7 +548,7 @@ class Trainer:
MSELoss('energies'),
lambda x: 0.5 * (torch.exp(2 * x) - 1) + l2())
-if len(params) > 0:
+if params:
raise ValueError('unrecognized parameter')
self.global_epoch = 0
...
@@ -161,8 +161,7 @@ class ChemicalSymbolsToInts:
def __init__(self, all_species):
self.rev_species = {}
-for i in range(len(all_species)):
-s = all_species[i]
+for i, s in enumerate(all_species):
self.rev_species[s] = i
def __call__(self, species):
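The last hunk replaces index-based iteration with `enumerate`. A self-contained stand-in for `ChemicalSymbolsToInts` (the class name, symbol list, and `__call__` body here are illustrative, not the library's exact implementation) showing the same mapping:

```python
class SymbolsToInts:
    """Illustrative stand-in: map chemical symbols to integer indices."""

    def __init__(self, all_species):
        self.rev_species = {}
        for i, s in enumerate(all_species):  # enumerate yields (index, symbol) pairs
            self.rev_species[s] = i

    def __call__(self, species):
        return [self.rev_species[s] for s in species]

converter = SymbolsToInts(['H', 'C', 'N', 'O'])
print(converter(['C', 'H', 'H', 'O']))  # [1, 0, 0, 3]
```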
...