Unverified Commit 5eb73cd8 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

fix order of SAEs (#456)

parent f967877b
......@@ -104,7 +104,7 @@ class Transformations:
"""Convert one reenterable iterable to another reenterable iterable"""
@staticmethod
def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'S', 'Cl')):
if species_order == 'periodic_table':
species_order = utils.PERIODIC_TABLE
idx = {k: i for i, k in enumerate(species_order)}
......@@ -141,7 +141,10 @@ class Transformations:
if len(counts[s]) != n + 1:
counts[s].append(0)
Y.append(d['energies'])
species = sorted(list(counts.keys()))
# sort based on the order in periodic table
species = sorted(list(counts.keys()), key=lambda x: utils.PERIODIC_TABLE.index(x))
X = [counts[s] for s in species]
if shifter.fit_intercept:
X.append([1] * n)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment