Unverified Commit 63cae6a0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Code quality improvements (#137)

parent 45252be6
# CI/build image for torchani.
# Base image is Arch-based (pacman below) and, judging by its name, presumably
# ships a master build of PyTorch — TODO confirm against the image's Dockerfile.
FROM zasdfgbnm/pytorch-master
# Documentation/test tooling from the Arch repositories, for both Python 2
# and Python 3 interpreters (sphinx, tqdm, matplotlib, pillow, flake8).
RUN pacman -Sy --noconfirm python-sphinx python2-sphinx python-tqdm python2-tqdm python2-matplotlib python-matplotlib python-pillow python2-pillow flake8
# pip-only dependencies, installed separately for each interpreter.
RUN pip install tensorboardX sphinx-gallery ase codecov nose && pip2 install tensorboardX sphinx-gallery ase codecov nose
# Copy the repository into the image and install the package under both
# Python 2 and Python 3.
COPY . /torchani
RUN cd torchani && pip install .
RUN cd torchani && pip2 install .
......@@ -58,8 +58,8 @@ class TestData(unittest.TestCase):
def testTensorShape(self):
for i in self.ds:
input, output = i
species, coordinates = torchani.utils.pad_coordinates(input)
input_, output = i
species, coordinates = torchani.utils.pad_coordinates(input_)
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
......@@ -72,8 +72,8 @@ class TestData(unittest.TestCase):
def testNoUnnecessaryPadding(self):
for i in self.ds:
for input in i[0]:
species, _ = input
for input_ in i[0]:
species, _ = input_
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
......
......@@ -50,8 +50,8 @@ class DictLoss(_Loss):
self.key = key
self.loss = loss
def forward(self, input, other):
return self.loss(input[self.key], other[self.key])
def forward(self, input_, other):
return self.loss(input_[self.key], other[self.key])
class PerAtomDictLoss(DictLoss):
......@@ -60,9 +60,9 @@ class PerAtomDictLoss(DictLoss):
by the caller. Currently the only reduce operation supported is averaging.
"""
def forward(self, input, other):
loss = self.loss(input[self.key], other[self.key])
num_atoms = (input['species'] >= 0).sum(dim=1)
def forward(self, input_, other):
loss = self.loss(input_[self.key], other[self.key])
num_atoms = (input_['species'] >= 0).sum(dim=1)
loss /= num_atoms.to(loss.dtype).to(loss.device)
n = loss.numel()
return loss.sum() / n
......@@ -102,8 +102,8 @@ class TransformedLoss(_Loss):
self.origin = origin
self.transform = transform
def forward(self, input, other):
return self.transform(self.origin(input, other))
def forward(self, input_, other):
return self.transform(self.origin(input_, other))
def RMSEMetric(key):
......
......@@ -101,11 +101,11 @@ def load_atomic_network(filename):
"""Returns an instance of :class:`torch.nn.Sequential` with hyperparameters
and parameters loaded NeuroChem's .nnf, .wparam and .bparam files."""
def decompress_nnf(buffer):
while buffer[0] != b'='[0]:
buffer = buffer[1:]
buffer = buffer[2:]
return bz2.decompress(buffer)[:-1].decode('ascii').strip()
def decompress_nnf(buffer_):
    """Decompress the contents of a NeuroChem ``.nnf`` file.

    The raw file is a plain-text prefix terminated by an ``'='`` byte plus
    one extra byte, followed by a bz2 stream; the decompressed payload
    carries one trailing byte that is discarded.

    Arguments:
        buffer_ (bytes): raw ``.nnf`` file contents.

    Returns:
        str: decompressed ASCII text with surrounding whitespace stripped.

    Raises:
        ValueError: if the buffer contains no ``'='`` marker.
    """
    # Find the marker in one C-level pass instead of slicing one byte off
    # the front per iteration (the old loop re-copied the buffer each time,
    # i.e. O(n^2) in the header length); then skip '=' plus one extra byte.
    start = buffer_.index(b'=') + 2
    return bz2.decompress(buffer_[start:])[:-1].decode('ascii').strip()
def parse_nnf(nnf_file):
# parse input file
......@@ -200,9 +200,9 @@ def load_atomic_network(filename):
networ_dir = os.path.dirname(filename)
with open(filename, 'rb') as f:
buffer = f.read()
buffer = decompress_nnf(buffer)
layer_setups = parse_nnf(buffer)
buffer_ = f.read()
buffer_ = decompress_nnf(buffer_)
layer_setups = parse_nnf(buffer_)
layers = []
for s in layer_setups:
......@@ -225,18 +225,18 @@ def load_atomic_network(filename):
return torch.nn.Sequential(*layers)
def load_model(species, dir):
def load_model(species, dir_):
"""Returns an instance of :class:`torchani.ANIModel` loaded from
NeuroChem's network directory.
Arguments:
species (:class:`collections.abc.Sequence`): Sequence of strings for
chemical symbols of each supported atom type in correct order.
dir (str): String for directory storing network configurations.
dir_ (str): String for directory storing network configurations.
"""
models = []
for i in species:
filename = os.path.join(dir, 'ANN-{}.nnf'.format(i))
filename = os.path.join(dir_, 'ANN-{}.nnf'.format(i))
models.append(load_atomic_network(filename))
return ANIModel(models)
......@@ -439,7 +439,7 @@ class Trainer:
return TreeExec().transform(tree)
def _construct(self, network_setup, params):
dir = os.path.dirname(os.path.abspath(self.filename))
dir_ = os.path.dirname(os.path.abspath(self.filename))
# delete ignored params
def del_if_exists(key):
......@@ -468,14 +468,14 @@ class Trainer:
assert_param('ntwshr', 0)
# load parameters
self.const_file = os.path.join(dir, params['sflparamsfile'])
self.const_file = os.path.join(dir_, params['sflparamsfile'])
self.consts = Constants(self.const_file)
self.aev_computer = AEVComputer(**self.consts)
del params['sflparamsfile']
self.sae_file = os.path.join(dir, params['atomEnergyFile'])
self.sae_file = os.path.join(dir_, params['atomEnergyFile'])
self.shift_energy = load_sae(self.sae_file)
del params['atomEnergyFile']
network_dir = os.path.join(dir, params['ntwkStoreDir'])
network_dir = os.path.join(dir_, params['ntwkStoreDir'])
if not os.path.exists(network_dir):
os.makedirs(network_dir)
self.model_checkpoint = os.path.join(network_dir, self.checkpoint_name)
......
......@@ -43,8 +43,8 @@ class ANIModel(torch.nn.ModuleList):
dtype=aev.dtype)
for i in present_species:
mask = (species_ == i)
input = aev.index_select(0, mask.nonzero().squeeze())
output.masked_scatter_(mask, self[i](input).squeeze())
input_ = aev.index_select(0, mask.nonzero().squeeze())
output.masked_scatter_(mask, self[i](input_).squeeze())
output = output.view_as(species)
return species, self.reducer(output, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment