Unverified Commit c7f549f4 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Improve according to codefactor.io (#130)

parent 9719edb1
# <img src=https://raw.githubusercontent.com/aiqm/torchani/master/logo1.png width=180/> Accurate Neural Network Potential on PyTorch
[![Codefresh build status]( https://g.codefresh.io/api/badges/pipeline/zasdfgbnm/aiqm%2Ftorchani%2Ftorchani?branch=master&type=cf-1)]( https://g.codefresh.io/repositories/aiqm/torchani/builds?filter=trigger:build;branch:master;service:5babc52a8a90dc40a407b05f~torchani)
[![CodeFactor](https://www.codefactor.io/repository/github/aiqm/torchani/badge/master)](https://www.codefactor.io/repository/github/aiqm/torchani/overview/master)
[![codecov](https://codecov.io/gh/aiqm/torchani/branch/master/graph/badge.svg)](https://codecov.io/gh/aiqm/torchani)
TorchANI is a pytorch implementation of ANI. It is currently under alpha release, which means, the API is not stable yet. If you find a bug of TorchANI, or have some feature request, feel free to open an issue on GitHub, or send us a pull request.
......
......@@ -16,10 +16,10 @@ def default_neighborlist(species, coordinates, cutoff):
"""Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
# vec has hape (conformations, atoms, atoms, 3) storing Rij vectors
distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
......@@ -267,7 +267,8 @@ class AEVComputer(torch.nn.Module):
radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype)
).sum(-3)
"""shape (conformations, atoms, present species, radial_length)"""
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
......
......@@ -38,7 +38,7 @@ def split_batch(natoms, species, coordinates):
natoms = natoms.tolist()
counts = []
for i in natoms:
if len(counts) == 0:
if not counts:
counts.append([i, 1])
continue
if i == counts[-1][0]:
......
......@@ -91,8 +91,7 @@ def MSELoss(key, per_atom=True):
"""Create MSE loss on the specified key."""
if per_atom:
return PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
else:
return DictLoss(key, torch.nn.MSELoss())
return DictLoss(key, torch.nn.MSELoss())
class TransformedLoss(_Loss):
......
......@@ -526,10 +526,9 @@ class Trainer:
# There is no plan to support the "L2" settings in
# input file before AdamW get merged into pytorch.
raise NotImplementedError('L2 not supported yet')
l2reg.append((0.5 * layer['l2valu'], module))
del layer['l2norm']
del layer['l2valu']
if len(layer) > 0:
if layer:
raise ValueError('unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
......@@ -549,7 +548,7 @@ class Trainer:
MSELoss('energies'),
lambda x: 0.5 * (torch.exp(2 * x) - 1) + l2())
if len(params) > 0:
if params:
raise ValueError('unrecognized parameter')
self.global_epoch = 0
......
......@@ -161,8 +161,7 @@ class ChemicalSymbolsToInts:
def __init__(self, all_species):
self.rev_species = {}
for i in range(len(all_species)):
s = all_species[i]
for i, s in enumerate(all_species):
self.rev_species[s] = i
def __call__(self, species):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment