Unverified Commit 1b58c3c7 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Fix subtract energies (#482)

* Add possibility of sorting in different order to subtract_self_energies

* make species_to_indices behave the same as subtract_self_energies

* Revert "make species_to_indices behave the same as subtract_self_energies"

This reverts commit a415df29bd6e270f3225f0b6721b3c873aeed40e.

* Fix examples

* make species_to_indices behave the same whithout breaking API

* Improve documentation to reflect correct usage

* Fix typos in docstring

* some more docstring issues

* more docstring

* Fix error in species_to_indices
parent 279e53ad
...@@ -95,7 +95,7 @@ except NameError: ...@@ -95,7 +95,7 @@ except NameError:
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560 batch_size = 2560
training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices(species_order).shuffle().split(0.8, None) training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter, species_order).species_to_indices(species_order).shuffle().split(0.8, None)
training = training.collate(batch_size).cache() training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache() validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies) print('Self atomic energies: ', energy_shifter.self_energies)
......
...@@ -52,7 +52,7 @@ batch_size = 2560 ...@@ -52,7 +52,7 @@ batch_size = 2560
training, validation = torchani.data.load( training, validation = torchani.data.load(
dspath, dspath,
additional_properties=('forces',) additional_properties=('forces',)
).subtract_self_energies(energy_shifter).species_to_indices(species_order).shuffle().split(0.8, None) ).subtract_self_energies(energy_shifter, species_order).species_to_indices(species_order).shuffle().split(0.8, None)
training = training.collate(batch_size).cache() training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache() validation = validation.collate(batch_size).cache()
......
...@@ -21,7 +21,13 @@ Available transformations are listed below: ...@@ -21,7 +21,13 @@ Available transformations are listed below:
- `pin_memory` copy the tensor to pinned memory so that later transfer - `pin_memory` copy the tensor to pinned memory so that later transfer
to cuda could be faster. to cuda could be faster.
You can also use `split` to split the iterable to pieces. Use `split` as: By default `species_to_indices` and `subtract_self_energies` order atoms by
atomic number. A special ordering can be used if requested, by calling
`species_to_indices(species_order)` or `subtract_self_energies(energy_shifter,
species_order)` however, this is definitely NOT recommended, it is best to
always order according to atomic number.
you can also use `split` to split the iterable to pieces. use `split` as:
.. code-block:: python .. code-block:: python
...@@ -119,7 +125,7 @@ class Transformations: ...@@ -119,7 +125,7 @@ class Transformations:
return IterableAdapter(reenterable_iterable_factory) return IterableAdapter(reenterable_iterable_factory)
@staticmethod @staticmethod
def subtract_self_energies(reenterable_iterable, self_energies=None): def subtract_self_energies(reenterable_iterable, self_energies=None, species_order=None):
intercept = 0.0 intercept = 0.0
shape_inference = False shape_inference = False
if isinstance(self_energies, utils.EnergyShifter): if isinstance(self_energies, utils.EnergyShifter):
...@@ -142,8 +148,11 @@ class Transformations: ...@@ -142,8 +148,11 @@ class Transformations:
counts[s].append(0) counts[s].append(0)
Y.append(d['energies']) Y.append(d['energies'])
# sort based on the order in periodic table # sort based on the order in periodic table by default
species = sorted(list(counts.keys()), key=lambda x: utils.PERIODIC_TABLE.index(x)) if species_order is None:
species_order = utils.PERIODIC_TABLE
species = sorted(list(counts.keys()), key=lambda x: species_order.index(x))
X = [counts[s] for s in species] X = [counts[s] for s in species]
if shifter.fit_intercept: if shifter.fit_intercept:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment