Unverified Commit 5eb73cd8 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

fix order of SAEs (#456)

parent f967877b
...@@ -104,7 +104,7 @@ class Transformations: ...@@ -104,7 +104,7 @@ class Transformations:
"""Convert one reenterable iterable to another reenterable iterable""" """Convert one reenterable iterable to another reenterable iterable"""
@staticmethod @staticmethod
def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')): def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'S', 'Cl')):
if species_order == 'periodic_table': if species_order == 'periodic_table':
species_order = utils.PERIODIC_TABLE species_order = utils.PERIODIC_TABLE
idx = {k: i for i, k in enumerate(species_order)} idx = {k: i for i, k in enumerate(species_order)}
...@@ -141,7 +141,10 @@ class Transformations: ...@@ -141,7 +141,10 @@ class Transformations:
if len(counts[s]) != n + 1: if len(counts[s]) != n + 1:
counts[s].append(0) counts[s].append(0)
Y.append(d['energies']) Y.append(d['energies'])
species = sorted(list(counts.keys()))
# sort based on the order in periodic table
species = sorted(list(counts.keys()), key=lambda x: utils.PERIODIC_TABLE.index(x))
X = [counts[s] for s in species] X = [counts[s] for s in species]
if shifter.fit_intercept: if shifter.fit_intercept:
X.append([1] * n) X.append([1] * n)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment