Commit 2409a22f authored by fanding2000

Format fix. More options in readme

parent ce29afea
import re
from basic_function import operation
from basic_function import data_classes
from basic_function import chemical_knowledge
import copy

def read_xyz_file(file_path):
    input_file = open(file_path, 'r')
    lines = input_file.readlines()
    input_file.close()
    number_of_atoms = int(lines[0])
    name = str(lines[1][:-1])
    atoms = []
    for index, line in enumerate(lines):
        split_line = list(filter(lambda x: x != '', re.split("\\s+", line)))
        if len(split_line) == 4 and operation.is_number(split_line[1]) and \
                operation.is_number(split_line[2]) and operation.is_number(split_line[3]):
            atoms.append(data_classes.Atom(element=split_line[0],
                                           cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])],
                                           atom_id=index - 2))
    if number_of_atoms != len(atoms):
        print("Warning! The number of atoms read does not match the atom count given in the header")
    molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name)
    return molecule
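
# Usage sketch (added for illustration, not part of the original code): the path below is
# hypothetical; read_xyz_file returns a data_classes.Molecule built from the XYZ body.
#
#     molecule = read_xyz_file("examples/benzene.xyz")
#     print(molecule.name, len(molecule.atoms))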

def write_cif_file(crystal, sym=False, name="zcx"):
    """
    Accept a Crystal object and return the CIF content as a list of lines.
    :param crystal: Crystal object
    :param sym: False: expand to P1 and write all atoms; True: keep the symmetry operations
    :param name: file name used in the data_ block when the crystal has no system name
    :return: cif_out
    The CIF content (a list of strings) can be written out as follows:
        target = open("D:\\zcx.cif", 'w')
        target.writelines(cif_out)
        target.close()
    """
    if crystal.system_name != "unknown":
        name = crystal.system_name
    cif_file = []
    cif_file.append("data_" + str(name) + "\n")
    if sym == False:
        if crystal.space_group == 1:
            crystal_temp = crystal
        else:
            crystal_temp = copy.deepcopy(crystal)
            crystal_temp.make_p1()
        cif_file.append("_symmetry_space_group_name_H-M 'P1'" + "\n")
        cif_file.append("_symmetry_Int_Tables_number 1" + "\n")
        cif_file.append("loop_" + "\n")
        cif_file.append("_symmetry_equiv_pos_site_id" + "\n")
        cif_file.append("_symmetry_equiv_pos_as_xyz" + "\n")
        cif_file.append("1 x,y,z" + "\n")
        cif_file.append("_cell_length_a " + str(crystal_temp.cell_para[0][0]) + "\n")
        cif_file.append("_cell_length_b " + str(crystal_temp.cell_para[0][1]) + "\n")
        cif_file.append("_cell_length_c " + str(crystal_temp.cell_para[0][2]) + "\n")
        cif_file.append("_cell_angle_alpha " + str(crystal_temp.cell_para[1][0]) + "\n")
        cif_file.append("_cell_angle_beta " + str(crystal_temp.cell_para[1][1]) + "\n")
        cif_file.append("_cell_angle_gamma " + str(crystal_temp.cell_para[1][2]) + "\n")
        cif_file.append("loop_" + "\n")
        cif_file.append("_atom_site_label" + "\n")
        cif_file.append("_atom_site_type_symbol" + "\n")
        cif_file.append("_atom_site_fract_x" + "\n")
        cif_file.append("_atom_site_fract_y" + "\n")
        cif_file.append("_atom_site_fract_z" + "\n")
        for i in range(0, len(crystal_temp.atoms)):
            cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n"
                            .format(i + 1, crystal_temp.atoms[i].element, crystal_temp.atoms[i].frac_xyz[0],
                                    crystal_temp.atoms[i].frac_xyz[1], crystal_temp.atoms[i].frac_xyz[2]))
        return cif_file
    elif sym == True:
        cif_file.append("_symmetry_space_group_name_H-M '{}'".format(chemical_knowledge.space_group[crystal.space_group][1]) + "\n")
        cif_file.append("_symmetry_Int_Tables_number {}".format(crystal.space_group) + "\n")
        cif_file.append("loop_" + "\n")
        cif_file.append("_symmetry_equiv_pos_site_id" + "\n")
        cif_file.append("_symmetry_equiv_pos_as_xyz" + "\n")
        for idx, SYMM in enumerate(crystal.SYMM):
            cif_file.append("{} {}".format(idx + 1, SYMM) + "\n")
        cif_file.append("_cell_length_a " + str(crystal.cell_para[0][0]) + "\n")
        cif_file.append("_cell_length_b " + str(crystal.cell_para[0][1]) + "\n")
        cif_file.append("_cell_length_c " + str(crystal.cell_para[0][2]) + "\n")
        cif_file.append("_cell_angle_alpha " + str(crystal.cell_para[1][0]) + "\n")
        cif_file.append("_cell_angle_beta " + str(crystal.cell_para[1][1]) + "\n")
        cif_file.append("_cell_angle_gamma " + str(crystal.cell_para[1][2]) + "\n")
        cif_file.append("loop_" + "\n")
        cif_file.append("_atom_site_label" + "\n")
        cif_file.append("_atom_site_type_symbol" + "\n")
        cif_file.append("_atom_site_fract_x" + "\n")
        cif_file.append("_atom_site_fract_y" + "\n")
        cif_file.append("_atom_site_fract_z" + "\n")
        for i in range(0, len(crystal.atoms)):
            cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n"
                            .format(i + 1, crystal.atoms[i].element, crystal.atoms[i].frac_xyz[0],
                                    crystal.atoms[i].frac_xyz[1], crystal.atoms[i].frac_xyz[2]))
        return cif_file
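
# Usage sketch (illustrative; the path and crystal object are hypothetical): build the CIF
# lines for a Crystal and write them to disk.
#
#     cif_out = write_cif_file(my_crystal, sym=False)
#     with open("out/my_crystal.cif", 'w') as target:
#         target.writelines(cif_out)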

def write_cifs_file(crystals, sym=False, name="zcx"):
    cifs_file = []
    for crystal in crystals:
        single_cif = write_cif_file(crystal, sym=sym, name=name)
        cifs_file.extend(single_cif)
    return cifs_file

def read_cif_file(file_path, on_sym_check=False, shut_up=False, system_name="unknown", comment_name="unknown"):
    input_file = open(file_path, 'r')
    lines = input_file.readlines()
    input_file.close()
    step_pickle = []
    crystal_all = []
    if system_name == "unknown":
        no_name = True
    else:
        no_name = False
    # first time scan
    for index, line in enumerate(lines):
        # find all the data_ block starting points
        if line.startswith("data_"):
            step_pickle.append(index)
    step_pickle.append(len(lines))
    # treat every block and build a crystal from it
    for m in range(0, len(step_pickle)-1):
        atoms = []
        atoms_P1 = []
        SYMM = []
        cell_para = [["unknown", "unknown", "unknown"], ["unknown", "unknown", "unknown"]]
        for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]):
            split_line = list(filter(lambda x: x != '', re.split("\\s+", line)))
            if line.startswith("#"):
                pass
            elif len(split_line) == 0:
                pass
            # read the loop of symmetry operations
            elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n":
                temp_number = 1
                while "_" not in lines[step_pickle[m]+index+1+temp_number]:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    temp_number += 1
                    if not operation.is_number(split_line_temp[0]):
                        SYMM.append(split_line_temp[0])
                    else:
                        SYMM.append(split_line_temp[1])
            elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ") == "_symmetry_equiv_pos_as_xyz\n":
                temp_number = 1
                while "_" not in lines[step_pickle[m]+index+2+temp_number]:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number])))
                    temp_number += 1
                    SYMM.append("".join(split_line_temp[1:]))
            elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]:
                # ase format
                temp_number = 1
                while "_" not in lines[step_pickle[m]+index+1+temp_number]:
                    if lines[step_pickle[m]+index+1+temp_number] == "\n":
                        temp_number += 1
                        continue
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    temp_number += 1
                    if not operation.is_number(split_line_temp[0]):
                        SYMM.append("".join(split_line_temp))
            # read the loop of atoms
            elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \
                    (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"):
                temp_number = 0
                while "_" in lines[step_pickle[m]+index+1+temp_number]:
                    if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n":
                        ele_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n":
                        x_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n":
                        y_pos = temp_number
                    elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n":
                        z_pos = temp_number
                    temp_number += 1
                how_long = temp_number
                while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long:
                    split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))
                    atoms.append(data_classes.Atom(element=split_line_temp[ele_pos],
                                                   frac_xyz=[float(split_line_temp[x_pos]), float(split_line_temp[y_pos]),
                                                             float(split_line_temp[z_pos])]))
                    temp_number += 1
                    if step_pickle[m]+index+1+temp_number == len(lines):
                        break
            elif split_line[0] == "_cell_length_a":
                cell_para[0][0] = float(split_line[1])
            elif split_line[0] == "_cell_length_b":
                cell_para[0][1] = float(split_line[1])
            elif split_line[0] == "_cell_length_c":
                cell_para[0][2] = float(split_line[1])
            elif split_line[0] == "_cell_angle_alpha":
                cell_para[1][0] = float(split_line[1])
            elif split_line[0] == "_cell_angle_beta":
                cell_para[1][1] = float(split_line[1])
            elif split_line[0] == "_cell_angle_gamma":
                cell_para[1][2] = float(split_line[1])
            elif "data_" in line:
                if no_name == True:
                    system_name = line[5:]
                    system_name = system_name.replace(" ", "_")
                    system_name = system_name.replace("\\", "_")
                    system_name = system_name.replace("\n", "")
        # expand the asymmetric unit to P1 using the symmetry operations found above
        for atom in atoms:
            all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM)
            for new_position in all_reflect_position:
                atoms_P1.append(data_classes.Atom(element=atom.element,
                                                  frac_xyz=[new_position[0], new_position[1], new_position[2]]))
        crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name))
        if on_sym_check == True:
            raise Exception("Unfinished part, TODO in code")
        if shut_up == False:
            if m % 100 == 0:
                print("{} structures have been treated".format(m))
    return crystal_all
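
# Usage sketch (illustrative, hypothetical path): a multi-block CIF yields one Crystal per
# data_ block, each expanded to P1 using the symmetry operations found in that block.
#
#     crystals = read_cif_file("examples/polymorphs.cif", shut_up=True)
#     print(len(crystals), crystals[0].system_name)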

def write_poscar_file(crystal, coordinates='frac', name="parser_zcx_create"):
    vasp_file = []
    vasp_file.append('{}\n'.format(name))
    vasp_file.append('1.0\n')
    cell_vect = crystal.cell_vect
    for vect in cell_vect:
        vasp_file.append("{:16.8f} {:16.8f} {:16.8f}\n".format(vect[0], vect[1], vect[2]))
    crystal.sort_by_element()
    vasp_file.append("".join("{:>6s}".format(x) for x in crystal.get_element()) + "\n")
    vasp_file.append("".join("{:>6.0f}".format(x) for x in crystal.get_element_amount()) + "\n")
    if coordinates == 'frac':
        vasp_file.append('Direct\n')
        for ELEMENT in crystal.get_element():
            for ATOM in crystal.atoms:
                if ATOM.element == ELEMENT:
                    vasp_file.append(
                        "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.frac_xyz[0], ATOM.frac_xyz[1], ATOM.frac_xyz[2]))
    elif coordinates == 'cart':
        vasp_file.append('Cartesian\n')
        for ELEMENT in crystal.get_element():
            for ATOM in crystal.atoms:
                if ATOM.element == ELEMENT:
                    vasp_file.append(
                        "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.cart_xyz[0], ATOM.cart_xyz[1], ATOM.cart_xyz[2]))
    else:
        raise Exception("Wrong coordinates type: {}".format(coordinates))
    return vasp_file
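
# Usage sketch (illustrative): convert a Crystal into a VASP POSCAR with fractional
# coordinates and write it out.
#
#     poscar = write_poscar_file(my_crystal, coordinates='frac')
#     with open("POSCAR", 'w') as target:
#         target.writelines(poscar)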

def read_ase_pbc_file(file_path, shut_up=False):
    input_file = open(file_path, 'r')
    lines = input_file.readlines()[2:]
    input_file.close()
    step_pickle = []
    crystal_all = []
    # first time scan
    for index, line in enumerate(lines):
        # find all the step starting points
        if line.startswith("Step "):
            step_pickle.append(index)
    step_pickle.append(len(lines))
    # treat every step and return a crystal
    for m in range(0, len(step_pickle)-1):
        atoms_P1 = []
        force_matrix = []
        position_matrix = []
        in_forces = False
        in_positions = False
        for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]):
            line = line.strip()
            # check Forces part
            if line.startswith("Forces:"):
                in_forces = True
                in_positions = False
                continue
            # check Positions part
            if line.startswith("Positions:"):
                in_positions = True
                in_forces = False
                continue
            # analyse Forces part
            if in_forces and line.startswith("[") and line.endswith("]"):
                line = line.replace("[", "").replace("]", "")
                force_matrix.append([float(x) for x in line.split()])
            # analyse Positions part
            if in_positions and line.startswith("[") and line.endswith("]"):
                line = line.replace("[", "").replace("]", "")
                position_matrix.append([float(x) for x in line.split()])
            if line.startswith("Elements:"):
                elements_string = line.strip().split(":", 1)[-1].strip()
                elements_string = elements_string[1:-1]
                elements = [elem.strip().strip("'") for elem in elements_string.split(",")]
            if line.startswith("cell:"):
                matrix_string = line[len("cell: Cell("):-1]
                rows = matrix_string.split("], [")
                cell_vect = [
                    [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")]
                    for row in rows
                ]
        for i in range(0, len(elements)):
            atoms_P1.append(data_classes.Atom(element=elements[i], cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]]))
        crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1))
    return crystal_all
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
A collection of functions for performing crystallographic and molecular operations,
such as symmetry application, supercell generation, and geometric analysis.
"""
# --- Standard Library Imports ---
import copy
import fractions
import re
from typing import Any, Dict, List, Optional, Tuple, Union

# --- Third-Party Imports ---
import networkx as nx
import numpy as np
import numpy.typing as npt
from scipy.spatial import cKDTree as KDTree

# --- Local Application Imports ---
from basic_function import chemical_knowledge, data_classes

# Type aliases for clarity
NDArrayFloat = npt.NDArray[np.float64]
CellVectors = List[List[float]]
SymmetryOperations = List[str]

def is_number(s: str) -> bool:
    """Checks if a string can be interpreted as a number (float or fraction).

    Args:
        s: The input string.

    Returns:
        True if the string represents a number, False otherwise.
    """
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        # Check for fractional representations like "1/2"
        float(fractions.Fraction(s))
        return True
    except ValueError:
        return False

def _parse_symmetry_operations(
    sym_ops: SymmetryOperations,
) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]:
    """Parses a list of symmetry operation strings into matrices.

    This is an internal helper function to avoid code duplication in public functions.

    Args:
        sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']).

    Returns:
        A tuple containing two lists:
            - A list of 3x3 rotation/reflection matrices (M).
            - A list of 1x3 translation vectors (C).

    Raises:
        ValueError: If a symmetry operation string is malformed.
    """
    rotation_matrices = []
    translation_vectors = []
    for sym_op_str in sym_ops:
        sym_op_parts = sym_op_str.lower().replace(" ", "").split(",")
        if len(sym_op_parts) != 3:
            raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.")
        matrix_m = np.zeros((3, 3))
        matrix_c = np.zeros((1, 3))
        for i, part in enumerate(sym_op_parts):
            # Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5'
            tokens = re.findall(r"([+-]?[xyz0-9./]+)", part)
            for token in tokens:
                token = token.strip()
                if not token:
                    continue
                if "x" in token:
                    matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0
                elif "y" in token:
                    matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0
                elif "z" in token:
                    matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0
                elif is_number(token):
                    matrix_c[0, i] += float(fractions.Fraction(token))
                else:
                    raise ValueError(f"Invalid fragment '{token}' in symmetry operation.")
        rotation_matrices.append(matrix_m)
        translation_vectors.append(matrix_c)
    return rotation_matrices, translation_vectors

def space_group_transfer_for_single_atom(
    frac_xyz: List[float], space_group_ops: SymmetryOperations
) -> List[List[float]]:
    """Applies space group symmetry operations to a single atomic coordinate.

    Args:
        frac_xyz: The fractional coordinates [x, y, z] of a single atom.
        space_group_ops: A list of space group symmetry operation strings.

    Returns:
        A list of all symmetrically equivalent fractional coordinates.
    """
    rot_matrices, trans_vectors = _parse_symmetry_operations(space_group_ops)
    equivalent_positions = []
    atom_pos = np.array(frac_xyz)
    for rot, trans in zip(rot_matrices, trans_vectors):
        new_pos = np.dot(atom_pos, rot.T) + trans.squeeze()
        equivalent_positions.append(new_pos.tolist())
    return equivalent_positions
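
# Worked example (added for illustration): for the operation '-x, y+1/2, -z' the position
# [0.1, 0.2, 0.3] maps to roughly [-0.1, 0.7, -0.3].
#
#     positions = space_group_transfer_for_single_atom([0.1, 0.2, 0.3], ["x,y,z", "-x,y+1/2,-z"])
#     # positions ~ [[0.1, 0.2, 0.3], [-0.1, 0.7, -0.3]]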

def super_cell(
    crystal: "data_classes.Crystal",
    cell_range: Optional[List[List[int]]] = None,
) -> "data_classes.Crystal":
    """Constructs a supercell from a unit cell.

    Args:
        crystal: The input Crystal object.
        cell_range: A list of ranges for each lattice vector, e.g.,
            [[-1, 1], [-1, 1], [-1, 1]] creates a 3x3x3 supercell.
            If None, defaults to [[-1, 1], [-1, 1], [-1, 1]].

    Returns:
        A new Crystal object representing the supercell.
    """
    if cell_range is None:
        cell_range = [[-1, 1], [-1, 1], [-1, 1]]
    dims = [r[1] - r[0] + 1 for r in cell_range]
    new_lattice = [
        [dim * val for val in crystal.cell_vect[i]] for i, dim in enumerate(dims)
    ]
    translation_vectors = []
    for h in range(cell_range[0][0], cell_range[0][1] + 1):
        for k in range(cell_range[1][0], cell_range[1][1] + 1):
            for l in range(cell_range[2][0], cell_range[2][1] + 1):
                translation_vectors.append([h, k, l])
    new_atoms = []
    for atom in crystal.atoms:
        for trans_vec in translation_vectors:
            new_frac_xyz = [
                (atom.frac_xyz[i] + trans_vec[i]) / dims[i] for i in range(3)
            ]
            new_atoms.append(
                data_classes.Atom(element=atom.element, frac_xyz=new_frac_xyz)
            )
    if crystal.energy != "unknown":
        total_cells = dims[0] * dims[1] * dims[2]
        new_energy = crystal.energy * total_cells
    else:
        new_energy = "unknown"
    return data_classes.Crystal(
        cell_vect=new_lattice, energy=new_energy, atoms=new_atoms
    )
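
# Usage sketch (illustrative; assumes a Crystal object obtained elsewhere): a 2x2x2
# supercell covers cell indices 0..1 along each lattice vector, so the atom count grows 8x.
#
#     big = super_cell(my_crystal, cell_range=[[0, 1], [0, 1], [0, 1]])
#     # len(big.atoms) == 8 * len(my_crystal.atoms)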

def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule":
    """Orients a molecule along its principal axes of inertia.

    The method uses the moment of inertia tensor to define a canonical orientation.
    The molecule's coordinates are modified in-place. For more details, see:
    http://sobereva.com/426

    Args:
        molecule: The Molecule object to be oriented.

    Returns:
        The same Molecule object with its atoms reoriented.
    """
    all_ele, all_cart = molecule.get_ele_and_cart()
    if len(all_cart) <= 1:
        return molecule  # No orientation needed for single atoms or empty molecules.
    masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele])
    relative_position = all_cart - molecule.get_center_of_mass()
    # Calculate the moment of inertia tensor
    I_xx = np.sum(masses * (relative_position[:, 1] ** 2 + relative_position[:, 2] ** 2))
    I_yy = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 2] ** 2))
    I_zz = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 1] ** 2))
    I_xy = -np.sum(masses * relative_position[:, 0] * relative_position[:, 1])
    I_xz = -np.sum(masses * relative_position[:, 0] * relative_position[:, 2])
    I_yz = -np.sum(masses * relative_position[:, 1] * relative_position[:, 2])
    I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]])
    # Eigenvectors of the inertia tensor are the principal axes.
    # np.linalg.eigh is used for symmetric matrices.
    eigenvalues, eigenvectors = np.linalg.eigh(I_matrix)
    principal_axes = eigenvectors.T
    # Project the relative positions onto the new axes system.
    new_positions = np.dot(relative_position, principal_axes.T)
    molecule.put_ele_cart_back(all_ele, new_positions)
    return molecule

def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat:
    """Generates a 3x3 rotation matrix from a 3D vector `v`.

    This function uses a mapping from a 3D vector to a quaternion, which is then
    used to construct the rotation matrix. This method avoids gimbal lock. A
    left-handed coordinate system is assumed.

    Args:
        v: A 3-element NumPy array used to generate the quaternion.

    Returns:
        A 3x3 rotation matrix.
    """
    # Ensure v elements are within valid ranges if necessary, though the
    # formulas handle most inputs gracefully.
    v0_sqrt = np.sqrt(max(v[0], 0))
    v0_1_sqrt = np.sqrt(max(1.0 - v[0], 0))
    angle1 = 2.0 * np.pi * v[1]
    angle2 = 2.0 * np.pi * v[2]
    # Quaternion components (x, y, z, w)
    qx = v0_1_sqrt * np.sin(angle1)
    qy = v0_1_sqrt * np.cos(angle1)
    qz = v0_sqrt * np.sin(angle2)
    qw = v0_sqrt * np.cos(angle2)
    return np.array([
        [1 - 2*qy**2 - 2*qz**2, 2*qx*qy + 2*qw*qz, 2*qx*qz - 2*qw*qy],
        [2*qx*qy - 2*qw*qz, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz + 2*qw*qx],
        [2*qx*qz + 2*qw*qy, 2*qy*qz - 2*qw*qx, 1 - 2*qx**2 - 2*qy**2]
    ])
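
# Usage sketch (illustrative): three uniform random numbers in [0, 1) give a random
# rotation; the quaternion built above has unit norm, so the matrix is orthogonal.
#
#     R = get_rotate_matrix(np.random.rand(3))
#     assert np.allclose(R @ R.T, np.eye(3))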

def f2c_matrix(
    cell_params: Tuple[List[float], List[float]]
) -> Optional[NDArrayFloat]:
    """Calculates the fractional-to-Cartesian transformation matrix.

    Args:
        cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]],
            where lengths are in Angstroms and angles are in degrees.

    Returns:
        The 3x3 transformation matrix, or None if cell parameters are invalid.
    """
    lengths, angles = cell_params
    a, b, c = lengths
    alpha, beta, gamma = np.deg2rad(angles)
    cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
    sin_g = np.sin(gamma)
    # Volume calculation term
    volume_term_sq = (
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    if volume_term_sq < 0:
        return None
    volume = a * b * c * np.sqrt(volume_term_sq)
    matrix = np.zeros((3, 3))
    matrix[0, 0] = a
    matrix[0, 1] = b * cos_g
    matrix[0, 2] = c * cos_b
    matrix[1, 1] = b * sin_g
    matrix[1, 2] = c * (cos_a - cos_b * cos_g) / sin_g
    matrix[2, 2] = volume / (a * b * sin_g)
    return matrix.T

def c2f_matrix(
    cell_params: Tuple[List[float], List[float]]
) -> Optional[NDArrayFloat]:
    """Calculates the Cartesian-to-fractional transformation matrix.

    This is the inverse of the matrix generated by `f2c_matrix`.

    Args:
        cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]],
            where lengths are in Angstroms and angles are in degrees.

    Returns:
        The 3x3 transformation matrix, or None if cell parameters are invalid.
    """
    f2c = f2c_matrix(cell_params)
    if f2c is None:
        return None
    try:
        return np.linalg.inv(f2c)
    except np.linalg.LinAlgError:
        return None
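
# Worked example (added for illustration): for a cubic 5 Angstrom cell the matrix is
# diag(5, 5, 5), so fractional (0.5, 0.5, 0.5) maps to Cartesian (2.5, 2.5, 2.5), and
# c2f_matrix returns the inverse, diag(0.2, 0.2, 0.2).
#
#     M = f2c_matrix(([5.0, 5.0, 5.0], [90.0, 90.0, 90.0]))
#     cart = np.dot(np.array([0.5, 0.5, 0.5]), M)
#     frac_back = np.dot(cart, c2f_matrix(([5.0, 5.0, 5.0], [90.0, 90.0, 90.0])))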

def apply_SYMM(
    frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations
) -> NDArrayFloat:
    """Applies symmetry operations to a single set of fractional coordinates.

    Args:
        frac_xyz: A NumPy array of fractional coordinates [x, y, z].
        symm_ops: A list of symmetry operation strings.

    Returns:
        A NumPy array of all symmetrically equivalent fractional coordinates.
    """
    rot_matrices, trans_vectors = _parse_symmetry_operations(symm_ops)
    equivalent_positions = [
        np.dot(frac_xyz, rot.T) + trans.squeeze()
        for rot, trans in zip(rot_matrices, trans_vectors)
    ]
    return np.array(equivalent_positions)

def apply_SYMM_with_element(
    elements: Union[str, List[str]],
    frac_xyzs: NDArrayFloat,
    symm_ops: SymmetryOperations,
) -> Tuple[NDArrayFloat, NDArrayFloat]:
    """Applies symmetry operations, returning new elements and coordinates.

    Args:
        elements: The element symbol(s) corresponding to the coordinates.
        frac_xyzs: A NumPy array of fractional coordinates.
        symm_ops: A list of symmetry operation strings.

    Returns:
        A tuple containing:
            - A NumPy array of element symbols for each new position.
            - A NumPy array of all symmetrically equivalent fractional coordinates.
    """
    equivalent_positions = apply_SYMM(frac_xyzs, symm_ops)
    num_ops = len(equivalent_positions)
    replicated_elements = np.tile(np.array(elements).squeeze(), (num_ops, 1))
    return replicated_elements, equivalent_positions

def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float:
    """Calculates the length of the longest space diagonal of a unit cell.

    The longest diagonal connects the origin (0,0,0) to the opposite
    corner (1,1,1) of the unit cell.

    Args:
        cell_vect: The three lattice vectors of the cell.

    Returns:
        The length of the longest diagonal in Angstroms.
    """
    cell_vect_np = np.array(cell_vect)
    diagonal_vector = np.sum(cell_vect_np, axis=0)
    return float(np.linalg.norm(diagonal_vector))

def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]:
    """Calculates inter-planar distances for primary crystallographic planes.

    This computes the distances for the (100), (010), and (001) families of planes.

    Args:
        cell_vect: The three lattice vectors [a, b, c] of the cell.

    Returns:
        A list of three distances [d_a, d_b, d_c], where d_a is the distance
        between planes parallel to the b-c plane, and so on.
    """
    distances = []
    vectors = [np.array(v) for v in cell_vect]
    # Permutations to calculate distance for each primary plane
    # (a to b-c plane, b to a-c plane, c to a-b plane)
    indices = [(0, 1, 2), (1, 0, 2), (2, 0, 1)]
    for i, j, k in indices:
        point_p = vectors[i]
        plane_v1 = vectors[j]
        plane_v2 = vectors[k]
        # Normal vector to the plane defined by plane_v1 and plane_v2
        normal_vector = np.cross(plane_v1, plane_v2)
        # Distance from point P to the plane is |N . P| / ||N||
        distance = abs(np.dot(normal_vector, point_p)) / np.linalg.norm(normal_vector)
        distances.append(distance)
    return distances
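
# Worked example (added for illustration): for an orthogonal 10 x 10 x 10 Angstrom cell
# every inter-planar distance equals the cell edge, so the result is [10.0, 10.0, 10.0].
#
#     d = calculate_distance_of_parallel_plane_in_crystal([[10, 0, 0], [0, 10, 0], [0, 0, 10]])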

def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool:
    """Detects if a crystal structure forms a connected framework via VdW radii.

    The method involves:
    1. Expanding the crystal to a P1 symmetry supercell.
    2. Building a 3x3x3 supercell to ensure periodic connections are considered.
    3. Constructing a graph where atoms are nodes and an edge exists if their
       distance is within a scaled sum of their van der Waals radii.
    4. Checking if the largest connected component in the graph is large enough
       to be considered a single, percolating framework.

    Args:
        crystal: The Crystal object to analyze.
        tolerance: A tolerance factor to scale the VdW radii sum.

    Returns:
        True if the structure is a connected framework, False otherwise.
    """
    crystal_temp = copy.deepcopy(crystal)
    crystal_temp.make_p1()
    crystal_temp.move_atom_into_cell()
    # Create a 3x3x3 supercell to check for connectivity across boundaries
    crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]])
    all_ele, all_carts = crystal_supercell.get_ele_and_cart()
    vdw_radii_map = chemical_knowledge.element_vdw_radii
    vdw_max = max(vdw_radii_map[el] for el in set(all_ele))
    distance_threshold = vdw_max * tolerance * 2
    # KDTree for efficient nearest-neighbor search
    tree = KDTree(all_carts)
    pairs = tree.query_pairs(r=distance_threshold)
    # Build a graph to find connected components
    graph = nx.Graph()
    graph.add_nodes_from(range(len(all_carts)))
    graph.add_edges_from(list(pairs))
    if not graph.nodes:
        return False
    # Find the largest connected component
    largest_cc = max(nx.connected_components(graph), key=len)
    # A heuristic to check for a percolating framework. A connected framework
    # should connect most atoms. The threshold '9' is empirical but robustly
    # distinguishes between isolated molecules and a fully connected lattice.
    # In a 3x3x3 supercell (27 unit cells), a connected framework should involve
    # significantly more atoms than in a few unit cells.
    return len(largest_cc) > 9 * len(crystal_temp.atoms)
\ No newline at end of file
from basic_function import format_parser
from basic_function import CSP_generator_normal
import os
import concurrent.futures
import sys


def process_crystal(seed, sg, molecules, output_path, add_name):
    generator = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
    total_atoms = sum(len(molecule.atoms) for molecule in molecules)
    new_crystal = generator.generate(seed=seed)
    sys.stdout.flush()
    if new_crystal is not None:
        cif_out = format_parser.write_cif_file(new_crystal)
        with open(f"{output_path}/structures/{add_name}_{sg}_{seed}_z{len(molecules)}_{total_atoms}.cif", 'w') as target:
            target.writelines(cif_out)
        return True
    return False

def CSP_generater_parallel(molecules, output_path, need_structure=100, space_group_list=[1], max_workers=8, add_name='', start_seed=1):
    space_groups = space_group_list
    accept_count = need_structure
    try:
        os.makedirs("{}/structures".format(output_path))
    except OSError:
        print("Warning: a 'structures' folder already exists in this path, skipping mkdir")
    for sg in space_groups:
        accept = 0
        seed = start_seed
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {}
            while accept < accept_count:
                # submit new tasks
                while len(futures) < max_workers and accept + len(futures) < accept_count:
                    future = executor.submit(process_crystal, seed, sg, molecules, output_path, add_name)
                    futures[future] = seed
                    seed += 1
                # check the finished tasks
                done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
                for future in done:
                    if future.result():
                        accept += 1
                    # remove it from the dict, whatever the result is
                    del futures[future]
                # cancel all remaining tasks once the requested number has been reached
                if accept >= accept_count:
                    for future in futures:
                        future.cancel()
                    break
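
# Usage sketch (illustrative; file names and paths are hypothetical): generate 50
# structures each in space groups 14 and 19 from a single-molecule XYZ file. The
# __main__ guard matters because ProcessPoolExecutor re-imports this module in its
# worker processes.
#
#     if __name__ == "__main__":
#         mols = [format_parser.read_xyz_file("examples/urea.xyz")]
#         CSP_generater_parallel(mols, "runs/urea", need_structure=50,
#                                space_group_list=[14, 19], max_workers=4)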
def CSP_generater_serial(molecules,output_path,need_structure = 100, densely_pack_method=False, space_group_list=[1]): def CSP_generater_serial(molecules,output_path,need_structure = 100, densely_pack_method=False, space_group_list=[1]):
""" """
:param molecules: a list [molecule1, molecule2, ...] :param molecules: a list [molecule1, molecule2, ...]
:param output_path: a str indicate the path of output folder :param output_path: a str indicate the path of output folder
:param need_structure: int :param need_structure: int
:param space_group_list:a list indicate the space group need to search :param space_group_list:a list indicate the space group need to search
""" """
try: try:
os.makedirs("{}\\structures".format(output_path)) os.makedirs("{}\\structures".format(output_path))
except: except:
print("Warning, these is already an structures folder in this path, skip mkdir") print("Warning, these is already an structures folder in this path, skip mkdir")
for sg in space_group_list: for sg in space_group_list:
aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg) aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg)
accept=0 accept=0
i=1 i=1
while accept<need_structure: while accept<need_structure:
new_crystal = aaa.generate(seed=i,densely_pack_method=densely_pack_method) new_crystal = aaa.generate(seed=i,densely_pack_method=densely_pack_method)
if new_crystal is None:
i += 1 i += 1
continue continue
cif_out = format_parser.write_cif_file(new_crystal) cif_out = format_parser.write_cif_file(new_crystal)
target = open("{}\\structures\\{}_{}.cif".format(output_path,sg,i), 'w') target = open("{}\\structures\\{}_{}.cif".format(output_path,sg,i), 'w')
target.writelines(cif_out) target.writelines(cif_out)
target.close() target.close()
accept+=1 accept+=1
i += 1 i += 1
def CSP_generater_one_test(molecules,output_path,space_group=1,seed=1): def CSP_generater_one_test(molecules,output_path,space_group=1,seed=1):
""" """
used to test the generator for a given space group and seed used to test the generator for a given space group and seed
:param molecules: a list [molecule1, molecule2, ...]
:param output_path: a str giving the path of the output folder
:param space_group: an int giving the space group to test
:param seed: int :param seed: int
:return: write out a cif :return: write out a cif
""" """
aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=space_group) aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=space_group)
new_crystal = aaa.generate(seed=seed, test=True) new_crystal = aaa.generate(seed=seed, test=True)
if new_crystal is None:
print("Return failed generate") print("Return failed generate")
else: else:
cif_out = format_parser.write_cif_file(new_crystal,sym=True) cif_out = format_parser.write_cif_file(new_crystal,sym=True)
target = open("{}\\{}_{}.cif".format(output_path,space_group,seed), 'w') target = open("{}\\{}_{}.cif".format(output_path,space_group,seed), 'w')
target.writelines(cif_out) target.writelines(cif_out)
target.close() target.close()
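# Editorial usage sketch (not part of the original module; the output path is a
# hypothetical example): check a single space group / seed combination before a full run.
#
#     CSP_generater_one_test(mols, "./debug_out", space_group=14, seed=7)
#
# On success this writes ./debug_out/14_7.cif with symmetry kept (write_cif_file is
# called with sym=True); on failure it only prints a message.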
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Provides functions for converting between different representations of a Provides functions for converting between different representations of a
crystallographic unit cell (cell parameters and lattice vectors) and for crystallographic unit cell (cell parameters and lattice vectors) and for
transforming atomic coordinates between fractional and Cartesian systems. transforming atomic coordinates between fractional and Cartesian systems.
""" """
# --- Standard Library Imports --- # --- Standard Library Imports ---
from typing import List, Tuple, Union from typing import List, Tuple, Union
# --- Third-Party Imports --- # --- Third-Party Imports ---
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
# --- Type Aliases for Clarity --- # --- Type Aliases for Clarity ---
NDArrayFloat = npt.NDArray[np.float64] NDArrayFloat = npt.NDArray[np.float64]
CellParameters = Tuple[List[float], List[float]] CellParameters = Tuple[List[float], List[float]]
CellVectors = Union[List[List[float]], NDArrayFloat] CellVectors = Union[List[List[float]], NDArrayFloat]
Coordinates = Union[List[float], NDArrayFloat] Coordinates = Union[List[float], NDArrayFloat]
def cell_para_to_vect( def cell_para_to_vect(
cell_para: CellParameters, check: bool = False cell_para: CellParameters, check: bool = False
) -> CellVectors: ) -> CellVectors:
"""Converts cell parameters to lattice vectors. """Converts cell parameters to lattice vectors.
The lattice vector `a` is aligned with the x-axis. The vector `b` lies in The lattice vector `a` is aligned with the x-axis. The vector `b` lies in
the xy-plane. the xy-plane.
Args: Args:
cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]], cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]],
where lengths are in Angstroms and angles are in degrees. where lengths are in Angstroms and angles are in degrees.
check: If True, asserts the input shape is correct. check: If True, asserts the input shape is correct.
Returns: Returns:
A 3x3 list of lists representing the cell vectors [a, b, c]. A 3x3 list of lists representing the cell vectors [a, b, c].
""" """
if check: if check:
shape_check = np.array(cell_para) shape_check = np.array(cell_para)
assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)." assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)."
lengths = cell_para[0] lengths = cell_para[0]
angles_deg = cell_para[1] angles_deg = cell_para[1]
a, b, c = lengths a, b, c = lengths
alpha, beta, gamma = np.deg2rad(angles_deg) alpha, beta, gamma = np.deg2rad(angles_deg)
cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
sin_g = np.sin(gamma) sin_g = np.sin(gamma)
# This term is related to the square of the cell volume. # This term is related to the square of the cell volume.
# It ensures the cell parameters are physically valid. # It ensures the cell parameters are physically valid.
volume_term_sq = ( volume_term_sq = (
1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
) )
# Ensure the argument for sqrt is non-negative # Ensure the argument for sqrt is non-negative
volume_term = np.sqrt(max(0, volume_term_sq)) volume_term = np.sqrt(max(0, volume_term_sq))
cell_vect = np.zeros((3, 3)) cell_vect = np.zeros((3, 3))
cell_vect[0, 0] = a cell_vect[0, 0] = a
cell_vect[1, 0] = b * cos_g cell_vect[1, 0] = b * cos_g
cell_vect[1, 1] = b * sin_g cell_vect[1, 1] = b * sin_g
cell_vect[2, 0] = c * cos_b cell_vect[2, 0] = c * cos_b
cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g
cell_vect[2, 2] = c * volume_term / sin_g cell_vect[2, 2] = c * volume_term / sin_g
return cell_vect.tolist() return cell_vect.tolist()
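# Worked example (editorial addition): for an orthorhombic cell with a=5, b=6, c=7
# Angstrom and all angles 90 degrees, the cosines vanish and sin(gamma)=1, so the
# construction above reduces to a diagonal matrix:
#
#     cell_para_to_vect(([5.0, 6.0, 7.0], [90.0, 90.0, 90.0]))
#     # ~ [[5.0, 0.0, 0.0],
#     #    [0.0, 6.0, 0.0],
#     #    [0.0, 0.0, 7.0]]   (off-diagonal terms are floating-point noise)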
def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters: def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters:
"""Converts lattice vectors to cell parameters. """Converts lattice vectors to cell parameters.
Args: Args:
cell_vect: A 3x3 array-like object representing the lattice vectors. cell_vect: A 3x3 array-like object representing the lattice vectors.
check: If True, asserts the input shape is correct. check: If True, asserts the input shape is correct.
Returns: Returns:
A tuple containing [[a, b, c], [alpha, beta, gamma]]. A tuple containing [[a, b, c], [alpha, beta, gamma]].
""" """
cell_vect_np = np.array(cell_vect) cell_vect_np = np.array(cell_vect)
if check: if check:
assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
vec_a, vec_b, vec_c = cell_vect_np vec_a, vec_b, vec_c = cell_vect_np
len_a = np.linalg.norm(vec_a) len_a = np.linalg.norm(vec_a)
len_b = np.linalg.norm(vec_b) len_b = np.linalg.norm(vec_b)
len_c = np.linalg.norm(vec_c) len_c = np.linalg.norm(vec_c)
lengths = [len_a, len_b, len_c] lengths = [len_a, len_b, len_c]
# Calculate angles using the dot product formula; handle potential floating point inaccuracies. # Calculate angles using the dot product formula; handle potential floating point inaccuracies.
def _calculate_angle(v1, v2, norm1, norm2): def _calculate_angle(v1, v2, norm1, norm2):
cosine_angle = np.dot(v1, v2) / (norm1 * norm2) cosine_angle = np.dot(v1, v2) / (norm1 * norm2)
# Clip to handle values slightly outside [-1, 1] due to precision issues # Clip to handle values slightly outside [-1, 1] due to precision issues
return np.arccos(np.clip(cosine_angle, -1.0, 1.0)) return np.arccos(np.clip(cosine_angle, -1.0, 1.0))
alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c) alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c)
beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c) beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c)
gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b) gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b)
angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist() angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist()
return (lengths, angles_deg) return (lengths, angles_deg)
def atom_frac_to_cart_by_cell_vect( def atom_frac_to_cart_by_cell_vect(
atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False
) -> List[float]: ) -> List[float]:
"""Converts fractional coordinates to Cartesian coordinates using cell vectors. """Converts fractional coordinates to Cartesian coordinates using cell vectors.
Args: Args:
atom_frac: A 3-element list or array of fractional coordinates. atom_frac: A 3-element list or array of fractional coordinates.
cell_vect: A 3x3 matrix of lattice vectors. cell_vect: A 3x3 matrix of lattice vectors.
check: If True, asserts input shapes are correct. check: If True, asserts input shapes are correct.
Returns: Returns:
A list of 3 Cartesian coordinates. A list of 3 Cartesian coordinates.
""" """
atom_frac_np = np.array(atom_frac) atom_frac_np = np.array(atom_frac)
cell_vect_np = np.array(cell_vect) cell_vect_np = np.array(cell_vect)
if check: if check:
assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements." assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements."
# The transformation is a linear combination of the basis vectors. # The transformation is a linear combination of the basis vectors.
# atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c # atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c
# This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]] # This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]]
atom_cart = np.dot(atom_frac_np, cell_vect_np) atom_cart = np.dot(atom_frac_np, cell_vect_np)
return atom_cart.tolist() return atom_cart.tolist()
def atom_frac_to_cart_by_cell_para( def atom_frac_to_cart_by_cell_para(
atom_frac: Coordinates, cell_para: CellParameters, check: bool = False atom_frac: Coordinates, cell_para: CellParameters, check: bool = False
) -> List[float]: ) -> List[float]:
"""Converts fractional coordinates to Cartesian using cell parameters. """Converts fractional coordinates to Cartesian using cell parameters.
Args: Args:
atom_frac: A 3-element list or array of fractional coordinates. atom_frac: A 3-element list or array of fractional coordinates.
cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]].
check: If True, performs validation checks in underlying functions. check: If True, performs validation checks in underlying functions.
Returns: Returns:
A list of 3 Cartesian coordinates. A list of 3 Cartesian coordinates.
""" """
cell_vect = cell_para_to_vect(cell_para, check=check) cell_vect = cell_para_to_vect(cell_para, check=check)
return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check) return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check)
def atom_cart_to_frac_by_cell_vect( def atom_cart_to_frac_by_cell_vect(
atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False
) -> List[float]: ) -> List[float]:
"""Converts Cartesian coordinates to fractional coordinates using cell vectors. """Converts Cartesian coordinates to fractional coordinates using cell vectors.
Args: Args:
atom_cart: A 3-element list or array of Cartesian coordinates. atom_cart: A 3-element list or array of Cartesian coordinates.
cell_vect: A 3x3 matrix of lattice vectors. cell_vect: A 3x3 matrix of lattice vectors.
check: If True, asserts input shapes are correct. check: If True, asserts input shapes are correct.
Returns: Returns:
A list of 3 fractional coordinates. A list of 3 fractional coordinates.
""" """
atom_cart_np = np.array(atom_cart) atom_cart_np = np.array(atom_cart)
cell_vect_np = np.array(cell_vect) cell_vect_np = np.array(cell_vect)
if check: if check:
assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)."
assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements." assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements."
# The transformation is atom_frac = atom_cart @ inverse(cell_vect) # The transformation is atom_frac = atom_cart @ inverse(cell_vect)
inv_cell_vect = np.linalg.inv(cell_vect_np) inv_cell_vect = np.linalg.inv(cell_vect_np)
atom_frac = np.dot(atom_cart_np, inv_cell_vect) atom_frac = np.dot(atom_cart_np, inv_cell_vect)
return atom_frac.tolist() return atom_frac.tolist()
def atom_cart_to_frac_by_cell_para( def atom_cart_to_frac_by_cell_para(
atom_cart: Coordinates, cell_para: CellParameters, check: bool = False atom_cart: Coordinates, cell_para: CellParameters, check: bool = False
) -> List[float]: ) -> List[float]:
"""Converts Cartesian coordinates to fractional using cell parameters. """Converts Cartesian coordinates to fractional using cell parameters.
Args: Args:
atom_cart: A 3-element list or array of Cartesian coordinates. atom_cart: A 3-element list or array of Cartesian coordinates.
cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]].
check: If True, performs validation checks in underlying functions. check: If True, performs validation checks in underlying functions.
Returns: Returns:
A list of 3 fractional coordinates. A list of 3 fractional coordinates.
""" """
cell_vect = cell_para_to_vect(cell_para, check=check) cell_vect = cell_para_to_vect(cell_para, check=check)
return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check) return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check)
def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float: def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float:
"""Calculates the volume of the unit cell. """Calculates the volume of the unit cell.
Args: Args:
cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or
a 3x3 matrix of cell vectors. a 3x3 matrix of cell vectors.
Returns: Returns:
The volume of the cell in cubic Angstroms. The volume of the cell in cubic Angstroms.
Raises: Raises:
ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3). ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3).
""" """
cell_info_np = np.array(cell_info) cell_info_np = np.array(cell_info)
if cell_info_np.shape == (3, 3): if cell_info_np.shape == (3, 3):
# Input is cell vectors, calculate volume using the scalar triple product. # Input is cell vectors, calculate volume using the scalar triple product.
return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2])))) return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2]))))
elif cell_info_np.shape == (2, 3): elif cell_info_np.shape == (2, 3):
# Input is cell parameters. # Input is cell parameters.
lengths, angles_deg = cell_info_np lengths, angles_deg = cell_info_np
a, b, c = lengths a, b, c = lengths
alpha, beta, gamma = np.deg2rad(angles_deg) alpha, beta, gamma = np.deg2rad(angles_deg)
cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma])
# Standard formula for volume from cell parameters # Standard formula for volume from cell parameters
volume_sq = ( volume_sq = (
a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g) a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g)
) )
return float(np.sqrt(max(0, volume_sq))) return float(np.sqrt(max(0, volume_sq)))
else: else:
raise ValueError(f"Cannot understand input shape {cell_info_np.shape} for `cell_info`.") raise ValueError(f"Cannot understand input shape {cell_info_np.shape} for `cell_info`.")
#!/bin/bash
TOP_DIR=$(pwd)
TAR_DIR="${TOP_DIR}/test"
mkdir -p "${TAR_DIR}"
cd "${TAR_DIR}"

# conformer search and structure generation
# change --mode to conformer_only or structure_only to run the two steps separately.
python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \
    --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16 \
    --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode all > generate.log 2>&1

# opt structures using mace; --batch_size 0 means auto batch size (only for mace)
mkdir -p "${TAR_DIR}/mace_opt"
cd "${TAR_DIR}/mace_opt"
python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
    --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \
    --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
    --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \
    --use_ordered_files true --model mace > opt.log 2>&1

# opt structures using 7net
# mkdir -p "${TAR_DIR}/7net_opt"
# cd "${TAR_DIR}/7net_opt"
# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \
#     --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \
#     --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \
#     --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \
#     --use_ordered_files true --model sevennet > opt.log 2>&1

# Postprocess the opt structures
python "${TOP_DIR}/post_process/clean_table.py"
## Make sure you have installed csd-python-api in the current env before executing the following commands
# conda activate ccdc
# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs"
# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80
from importlib.metadata import version from importlib.metadata import version
from packaging.version import Version from packaging.version import Version
__version__ = version('sevenn') __version__ = version('sevenn')
from e3nn import __version__ as e3nn_ver from e3nn import __version__ as e3nn_ver
if Version(e3nn_ver) < Version('0.5.0'): if Version(e3nn_ver) < Version('0.5.0'):
raise ValueError( raise ValueError(
'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient ' 'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
'convention.' 'convention.'
) )
import os import os
from enum import Enum from enum import Enum
from typing import Dict from typing import Dict
import torch import torch
import sevenn._keys as KEY import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus from sevenn.nn.activation import ShiftedSoftPlus
NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118 NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118
IMPLEMENTED_RADIAL_BASIS = ['bessel'] IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR'] IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of parallel model # TODO: support None. This became difficult because of parallel model
IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear'] IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear']
IMPLEMENTED_INTERACTION_TYPE = ['nequip'] IMPLEMENTED_INTERACTION_TYPE = ['nequip']
IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']
SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_ERROR_TYPES = [ SUPPORTING_ERROR_TYPES = [
'TotalEnergy', 'TotalEnergy',
'Energy', 'Energy',
'Force', 'Force',
'Stress', 'Stress',
'Stress_GPa', 'Stress_GPa',
'TotalLoss', 'TotalLoss',
] ]
IMPLEMENTED_MODEL = ['E3_equivariant_model'] IMPLEMENTED_MODEL = ['E3_equivariant_model']
# string input to real torch function # string input to real torch function
ACTIVATION = { ACTIVATION = {
'relu': torch.nn.functional.relu, 'relu': torch.nn.functional.relu,
'silu': torch.nn.functional.silu, 'silu': torch.nn.functional.silu,
'tanh': torch.tanh, 'tanh': torch.tanh,
'abs': torch.abs, 'abs': torch.abs,
'ssp': ShiftedSoftPlus, 'ssp': ShiftedSoftPlus,
'sigmoid': torch.sigmoid, 'sigmoid': torch.sigmoid,
'elu': torch.nn.functional.elu, 'elu': torch.nn.functional.elu,
} }
ACTIVATION_FOR_EVEN = { ACTIVATION_FOR_EVEN = {
'ssp': ShiftedSoftPlus, 'ssp': ShiftedSoftPlus,
'silu': torch.nn.functional.silu, 'silu': torch.nn.functional.silu,
} }
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs} ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD} ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}
_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials') _prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth' SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth' SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth' SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth' SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth' SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth' SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'
_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download' _git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
CHECKPOINT_DOWNLOAD_LINKS = { CHECKPOINT_DOWNLOAD_LINKS = {
SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth', SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth', SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
} }
# to avoid torch script to compile torch_geometry.data # to avoid torch script to compile torch_geometry.data
AtomGraphDataType = Dict[str, torch.Tensor] AtomGraphDataType = Dict[str, torch.Tensor]
class LossType(Enum): # only used for train_v1, do not use it afterwards class LossType(Enum): # only used for train_v1, do not use it afterwards
ENERGY = 'energy' # eV or eV/atom ENERGY = 'energy' # eV or eV/atom
FORCE = 'force' # eV/A FORCE = 'force' # eV/A
STRESS = 'stress' # kB STRESS = 'stress' # kB
def error_record_condition(x): def error_record_condition(x):
if type(x) is not list: if type(x) is not list:
return False return False
for v in x: for v in x:
if type(v) is not list or len(v) != 2: if type(v) is not list or len(v) != 2:
return False return False
if v[0] not in SUPPORTING_ERROR_TYPES: if v[0] not in SUPPORTING_ERROR_TYPES:
return False return False
if v[0] == 'TotalLoss': if v[0] == 'TotalLoss':
continue continue
if v[1] not in SUPPORTING_METRICS: if v[1] not in SUPPORTING_METRICS:
return False return False
return True return True
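# Editorial example (not part of the original file): a value that satisfies
# error_record_condition above. Each entry is an [error_type, metric] pair; the
# metric attached to 'TotalLoss' is never checked, so any placeholder is accepted there.
#
#     _example_error_record = [
#         ['Energy', 'RMSE'],
#         ['Force', 'MAE'],
#         ['Stress_GPa', 'RMSE'],
#         ['TotalLoss', 'None'],
#     ]
#     assert error_record_condition(_example_error_record)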
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = { DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
KEY.CUTOFF: 4.5, KEY.CUTOFF: 4.5,
KEY.NODE_FEATURE_MULTIPLICITY: 32, KEY.NODE_FEATURE_MULTIPLICITY: 32,
KEY.IRREPS_MANUAL: False, KEY.IRREPS_MANUAL: False,
KEY.LMAX: 1, KEY.LMAX: 1,
KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax
KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax
KEY.IS_PARITY: True, KEY.IS_PARITY: True,
KEY.NUM_CONVOLUTION: 3, KEY.NUM_CONVOLUTION: 3,
KEY.RADIAL_BASIS: { KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: 'bessel', KEY.RADIAL_BASIS_NAME: 'bessel',
}, },
KEY.CUTOFF_FUNCTION: { KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
}, },
KEY.ACTIVATION_RADIAL: 'silu', KEY.ACTIVATION_RADIAL: 'silu',
KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'}, KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'}, KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64], KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
# KEY.AVG_NUM_NEIGH: True, # deprecated # KEY.AVG_NUM_NEIGH: True, # deprecated
# KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated # KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
KEY.CONV_DENOMINATOR: 'avg_num_neigh', KEY.CONV_DENOMINATOR: 'avg_num_neigh',
KEY.TRAIN_DENOMINTAOR: False, KEY.TRAIN_DENOMINTAOR: False,
KEY.TRAIN_SHIFT_SCALE: False, KEY.TRAIN_SHIFT_SCALE: False,
# KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True # KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.USE_BIAS_IN_LINEAR: False, KEY.USE_BIAS_IN_LINEAR: False,
KEY.USE_MODAL_NODE_EMBEDDING: False, KEY.USE_MODAL_NODE_EMBEDDING: False,
KEY.USE_MODAL_SELF_INTER_INTRO: False, KEY.USE_MODAL_SELF_INTER_INTRO: False,
KEY.USE_MODAL_SELF_INTER_OUTRO: False, KEY.USE_MODAL_SELF_INTER_OUTRO: False,
KEY.USE_MODAL_OUTPUT_BLOCK: False, KEY.USE_MODAL_OUTPUT_BLOCK: False,
KEY.READOUT_AS_FCN: False, KEY.READOUT_AS_FCN: False,
# Applied only if readout_as_fcn is True
KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30], KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
KEY.READOUT_FCN_ACTIVATION: 'relu', KEY.READOUT_FCN_ACTIVATION: 'relu',
KEY.SELF_CONNECTION_TYPE: 'nequip', KEY.SELF_CONNECTION_TYPE: 'nequip',
KEY.INTERACTION_TYPE: 'nequip', KEY.INTERACTION_TYPE: 'nequip',
KEY._NORMALIZE_SPH: True, KEY._NORMALIZE_SPH: True,
KEY.CUEQUIVARIANCE_CONFIG: {}, KEY.CUEQUIVARIANCE_CONFIG: {},
} }
# Basically, "If provided, it should be type of ..." # Basically, "If provided, it should be type of ..."
MODEL_CONFIG_CONDITION = { MODEL_CONFIG_CONDITION = {
KEY.NODE_FEATURE_MULTIPLICITY: int, KEY.NODE_FEATURE_MULTIPLICITY: int,
KEY.LMAX: int, KEY.LMAX: int,
KEY.LMAX_EDGE: int, KEY.LMAX_EDGE: int,
KEY.LMAX_NODE: int, KEY.LMAX_NODE: int,
KEY.IS_PARITY: bool, KEY.IS_PARITY: bool,
KEY.RADIAL_BASIS: { KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS, KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS,
}, },
KEY.CUTOFF_FUNCTION: { KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION, KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION,
}, },
KEY.CUTOFF: float, KEY.CUTOFF: float,
KEY.NUM_CONVOLUTION: int, KEY.NUM_CONVOLUTION: int,
KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
or x or x
in [ in [
'avg_num_neigh', 'avg_num_neigh',
'sqrt_avg_num_neigh', 'sqrt_avg_num_neigh',
], ],
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list, KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
KEY.TRAIN_SHIFT_SCALE: bool, KEY.TRAIN_SHIFT_SCALE: bool,
KEY.TRAIN_DENOMINTAOR: bool, KEY.TRAIN_DENOMINTAOR: bool,
KEY.USE_BIAS_IN_LINEAR: bool, KEY.USE_BIAS_IN_LINEAR: bool,
KEY.USE_MODAL_NODE_EMBEDDING: bool, KEY.USE_MODAL_NODE_EMBEDDING: bool,
KEY.USE_MODAL_SELF_INTER_INTRO: bool, KEY.USE_MODAL_SELF_INTER_INTRO: bool,
KEY.USE_MODAL_SELF_INTER_OUTRO: bool, KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
KEY.USE_MODAL_OUTPUT_BLOCK: bool, KEY.USE_MODAL_OUTPUT_BLOCK: bool,
KEY.READOUT_AS_FCN: bool, KEY.READOUT_AS_FCN: bool,
KEY.READOUT_FCN_HIDDEN_NEURONS: list, KEY.READOUT_FCN_HIDDEN_NEURONS: list,
KEY.READOUT_FCN_ACTIVATION: str, KEY.READOUT_FCN_ACTIVATION: str,
KEY.ACTIVATION_RADIAL: str, KEY.ACTIVATION_RADIAL: str,
KEY.SELF_CONNECTION_TYPE: lambda x: ( KEY.SELF_CONNECTION_TYPE: lambda x: (
x in IMPLEMENTED_SELF_CONNECTION_TYPE x in IMPLEMENTED_SELF_CONNECTION_TYPE
or ( or (
isinstance(x, list) isinstance(x, list)
and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x) and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x)
) )
), ),
KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE, KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
KEY._NORMALIZE_SPH: bool, KEY._NORMALIZE_SPH: bool,
KEY.CUEQUIVARIANCE_CONFIG: dict, KEY.CUEQUIVARIANCE_CONFIG: dict,
} }
def model_defaults(config): def model_defaults(config):
defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG
if KEY.READOUT_AS_FCN not in config: if KEY.READOUT_AS_FCN not in config:
config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN] config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN]
if config[KEY.READOUT_AS_FCN] is False: if config[KEY.READOUT_AS_FCN] is False:
defaults.pop(KEY.READOUT_FCN_ACTIVATION, None) defaults.pop(KEY.READOUT_FCN_ACTIVATION, None)
defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None) defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None)
return defaults return defaults
DEFAULT_DATA_CONFIG = { DEFAULT_DATA_CONFIG = {
KEY.DTYPE: 'single', KEY.DTYPE: 'single',
KEY.DATA_FORMAT: 'ase', KEY.DATA_FORMAT: 'ase',
KEY.DATA_FORMAT_ARGS: {}, KEY.DATA_FORMAT_ARGS: {},
KEY.SAVE_DATASET: False, KEY.SAVE_DATASET: False,
KEY.SAVE_BY_LABEL: False, KEY.SAVE_BY_LABEL: False,
KEY.SAVE_BY_TRAIN_VALID: False, KEY.SAVE_BY_TRAIN_VALID: False,
KEY.RATIO: 0.0, KEY.RATIO: 0.0,
KEY.BATCH_SIZE: 6, KEY.BATCH_SIZE: 6,
KEY.PREPROCESS_NUM_CORES: 1, KEY.PREPROCESS_NUM_CORES: 1,
KEY.COMPUTE_STATISTICS: True, KEY.COMPUTE_STATISTICS: True,
KEY.DATASET_TYPE: 'graph', KEY.DATASET_TYPE: 'graph',
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: False, # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.USE_MODAL_WISE_SHIFT: False, KEY.USE_MODAL_WISE_SHIFT: False,
KEY.USE_MODAL_WISE_SCALE: False, KEY.USE_MODAL_WISE_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean', KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms', KEY.SCALE: 'force_rms',
# KEY.DATA_SHUFFLE: True, # KEY.DATA_SHUFFLE: True,
# KEY.DATA_WEIGHT: False, # KEY.DATA_WEIGHT: False,
# KEY.DATA_MODALITY: False, # KEY.DATA_MODALITY: False,
} }
DATA_CONFIG_CONDITION = { DATA_CONFIG_CONDITION = {
KEY.DTYPE: str, KEY.DTYPE: str,
KEY.DATA_FORMAT: str, KEY.DATA_FORMAT: str,
KEY.DATA_FORMAT_ARGS: dict, KEY.DATA_FORMAT_ARGS: dict,
KEY.SAVE_DATASET: str, KEY.SAVE_DATASET: str,
KEY.SAVE_BY_LABEL: bool, KEY.SAVE_BY_LABEL: bool,
KEY.SAVE_BY_TRAIN_VALID: bool, KEY.SAVE_BY_TRAIN_VALID: bool,
KEY.RATIO: float, KEY.RATIO: float,
KEY.BATCH_SIZE: int, KEY.BATCH_SIZE: int,
KEY.PREPROCESS_NUM_CORES: int, KEY.PREPROCESS_NUM_CORES: int,
KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'], KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'],
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool, # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT, KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.USE_MODAL_WISE_SHIFT: bool, KEY.USE_MODAL_WISE_SHIFT: bool,
KEY.USE_MODAL_WISE_SCALE: bool, KEY.USE_MODAL_WISE_SCALE: bool,
# KEY.DATA_SHUFFLE: bool, # KEY.DATA_SHUFFLE: bool,
KEY.COMPUTE_STATISTICS: bool, KEY.COMPUTE_STATISTICS: bool,
# KEY.DATA_WEIGHT: bool, # KEY.DATA_WEIGHT: bool,
# KEY.DATA_MODALITY: bool, # KEY.DATA_MODALITY: bool,
} }
def data_defaults(config): def data_defaults(config):
defaults = DEFAULT_DATA_CONFIG defaults = DEFAULT_DATA_CONFIG
if KEY.LOAD_VALIDSET in config: if KEY.LOAD_VALIDSET in config:
defaults.pop(KEY.RATIO, None) defaults.pop(KEY.RATIO, None)
return defaults return defaults
DEFAULT_TRAINING_CONFIG = { DEFAULT_TRAINING_CONFIG = {
KEY.RANDOM_SEED: 1, KEY.RANDOM_SEED: 1,
KEY.EPOCH: 300, KEY.EPOCH: 300,
KEY.LOSS: 'mse', KEY.LOSS: 'mse',
KEY.LOSS_PARAM: {}, KEY.LOSS_PARAM: {},
KEY.OPTIMIZER: 'adam', KEY.OPTIMIZER: 'adam',
KEY.OPTIM_PARAM: {}, KEY.OPTIM_PARAM: {},
KEY.SCHEDULER: 'exponentiallr', KEY.SCHEDULER: 'exponentiallr',
KEY.SCHEDULER_PARAM: {}, KEY.SCHEDULER_PARAM: {},
KEY.FORCE_WEIGHT: 0.1, KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.PER_EPOCH: 5, KEY.PER_EPOCH: 5,
# KEY.USE_TESTSET: False, # KEY.USE_TESTSET: False,
KEY.CONTINUE: { KEY.CONTINUE: {
KEY.CHECKPOINT: False, KEY.CHECKPOINT: False,
KEY.RESET_OPTIMIZER: False, KEY.RESET_OPTIMIZER: False,
KEY.RESET_SCHEDULER: False, KEY.RESET_SCHEDULER: False,
KEY.RESET_EPOCH: False, KEY.RESET_EPOCH: False,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True, KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True, KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
}, },
# KEY.DEFAULT_MODAL: 'common', # KEY.DEFAULT_MODAL: 'common',
KEY.CSV_LOG: 'log.csv', KEY.CSV_LOG: 'log.csv',
KEY.NUM_WORKERS: 0, KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True, KEY.IS_TRAIN_STRESS: True,
KEY.TRAIN_SHUFFLE: True, KEY.TRAIN_SHUFFLE: True,
KEY.ERROR_RECORD: [ KEY.ERROR_RECORD: [
['Energy', 'RMSE'], ['Energy', 'RMSE'],
['Force', 'RMSE'], ['Force', 'RMSE'],
['Stress', 'RMSE'], ['Stress', 'RMSE'],
['TotalLoss', 'None'], ['TotalLoss', 'None'],
], ],
KEY.BEST_METRIC: 'TotalLoss', KEY.BEST_METRIC: 'TotalLoss',
KEY.USE_WEIGHT: False, KEY.USE_WEIGHT: False,
KEY.USE_MODALITY: False, KEY.USE_MODALITY: False,
} }
TRAINING_CONFIG_CONDITION = { TRAINING_CONFIG_CONDITION = {
KEY.RANDOM_SEED: int, KEY.RANDOM_SEED: int,
KEY.EPOCH: int, KEY.EPOCH: int,
KEY.FORCE_WEIGHT: float, KEY.FORCE_WEIGHT: float,
KEY.STRESS_WEIGHT: float, KEY.STRESS_WEIGHT: float,
KEY.USE_TESTSET: None, # Not used KEY.USE_TESTSET: None, # Not used
KEY.NUM_WORKERS: int, KEY.NUM_WORKERS: int,
KEY.PER_EPOCH: int, KEY.PER_EPOCH: int,
KEY.CONTINUE: { KEY.CONTINUE: {
KEY.CHECKPOINT: str, KEY.CHECKPOINT: str,
KEY.RESET_OPTIMIZER: bool, KEY.RESET_OPTIMIZER: bool,
KEY.RESET_SCHEDULER: bool, KEY.RESET_SCHEDULER: bool,
KEY.RESET_EPOCH: bool, KEY.RESET_EPOCH: bool,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool, KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool, KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
}, },
KEY.DEFAULT_MODAL: str, KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool, KEY.IS_TRAIN_STRESS: bool,
KEY.TRAIN_SHUFFLE: bool, KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition, KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str, KEY.BEST_METRIC: str,
KEY.CSV_LOG: str, KEY.CSV_LOG: str,
KEY.USE_MODALITY: bool, KEY.USE_MODALITY: bool,
KEY.USE_WEIGHT: bool, KEY.USE_WEIGHT: bool,
} }
def train_defaults(config): def train_defaults(config):
defaults = DEFAULT_TRAINING_CONFIG defaults = DEFAULT_TRAINING_CONFIG
if KEY.IS_TRAIN_STRESS not in config: if KEY.IS_TRAIN_STRESS not in config:
config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
if not config[KEY.IS_TRAIN_STRESS]: if not config[KEY.IS_TRAIN_STRESS]:
defaults.pop(KEY.STRESS_WEIGHT, None) defaults.pop(KEY.STRESS_WEIGHT, None)
return defaults return defaults
""" """
How to add a new feature?
1. Add the new key to this file.
2. Add the new key to _const.py
    2.1. if the type of the input is consistent,
         write an adequate condition and default in _const.py.
    2.2. if the type of the input is not consistent,
         you must add your own input validation code to
         parse_input.py
""" """
from typing import Final from typing import Final
# see # see
# https://github.com/pytorch/pytorch/issues/52312 # https://github.com/pytorch/pytorch/issues/52312
# for FYI # for FYI
# ~~ keys ~~ # # ~~ keys ~~ #
# PyG : primitive key of torch_geometric.data.Data type # PyG : primitive key of torch_geometric.data.Data type
# ==================================================# # ==================================================#
# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ #
# ==================================================# # ==================================================#
# some raw properties of graph # some raw properties of graph
ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N) ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N)
POS: Final[str] = 'pos' # (N, 3) PyG POS: Final[str] = 'pos' # (N, 3) PyG
CELL: Final[str] = 'cell_lattice_vectors' # (3, 3) CELL: Final[str] = 'cell_lattice_vectors' # (3, 3)
CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3) CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3)
CELL_VOLUME: Final[str] = 'cell_volume' CELL_VOLUME: Final[str] = 'cell_volume'
EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3) EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3)
EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1) EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1)
# some primary data of graph # some primary data of graph
EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG
ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes
NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG
NODE_FEATURE_GHOST: Final[str] = 'x_ghost' NODE_FEATURE_GHOST: Final[str] = 'x_ghost'
NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot
MODAL_ATTR: Final[str] = ( MODAL_ATTR: Final[str] = (
'modal_attr' # (1, N_modalities) for handling multi-modal 'modal_attr' # (1, N_modalities) for handling multi-modal
) )
MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal
EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics) EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics)
EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding) EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding)
# inputs of loss function # inputs of loss function
ENERGY: Final[str] = 'total_energy' # (1) ENERGY: Final[str] = 'total_energy' # (1)
FORCE: Final[str] = 'force_of_atoms' # (N, 3) FORCE: Final[str] = 'force_of_atoms' # (N, 3)
STRESS: Final[str] = 'stress' # (6) STRESS: Final[str] = 'stress' # (6)
# This is for training, per atom scale. # This is for training, per atom scale.
SCALED_ENERGY: Final[str] = 'scaled_total_energy' SCALED_ENERGY: Final[str] = 'scaled_total_energy'
# general outputs of models # general outputs of models
SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy' SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy'
ATOMIC_ENERGY: Final[str] = 'atomic_energy' ATOMIC_ENERGY: Final[str] = 'atomic_energy'
PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy'
PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy'
PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' PER_ATOM_ENERGY: Final[str] = 'per_atom_energy'
PRED_FORCE: Final[str] = 'inferred_force' PRED_FORCE: Final[str] = 'inferred_force'
SCALED_FORCE: Final[str] = 'scaled_force' SCALED_FORCE: Final[str] = 'scaled_force'
PRED_STRESS: Final[str] = 'inferred_stress' PRED_STRESS: Final[str] = 'inferred_stress'
SCALED_STRESS: Final[str] = 'scaled_stress' SCALED_STRESS: Final[str] = 'scaled_stress'
# very general data property for AtomGraphData # very general data property for AtomGraphData
NUM_ATOMS: Final[str] = 'num_atoms' # int NUM_ATOMS: Final[str] = 'num_atoms' # int
NUM_GHOSTS: Final[str] = 'num_ghosts' NUM_GHOSTS: Final[str] = 'num_ghosts'
NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu
USER_LABEL: Final[str] = 'user_label' USER_LABEL: Final[str] = 'user_label'
DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data
DATA_MODALITY: Final[str] = ( DATA_MODALITY: Final[str] = (
'data_modality' # modality of given data. e.g. PBE and SCAN 'data_modality' # modality of given data. e.g. PBE and SCAN
) )
BATCH: Final[str] = 'batch' BATCH: Final[str] = 'batch'
TAG = 'tag' # replace USER_LABEL TAG = 'tag' # replace USER_LABEL
# etc # etc
SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp' SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp'
BATCH_SIZE: Final[str] = 'batch_size' BATCH_SIZE: Final[str] = 'batch_size'
INFO: Final[str] = 'data_info' INFO: Final[str] = 'data_info'
# something special # something special
LABEL_NONE: Final[str] = 'No_label' LABEL_NONE: Final[str] = 'No_label'
# ==================================================# # ==================================================#
# ~~~~~~ KEY for train/data configuration ~~~~~~~~ # # ~~~~~~ KEY for train/data configuration ~~~~~~~~ #
# ==================================================# # ==================================================#
PREPROCESS_NUM_CORES = 'preprocess_num_cores' PREPROCESS_NUM_CORES = 'preprocess_num_cores'
SAVE_DATASET = 'save_dataset_path' SAVE_DATASET = 'save_dataset_path'
SAVE_BY_LABEL = 'save_by_label' SAVE_BY_LABEL = 'save_by_label'
SAVE_BY_TRAIN_VALID = 'save_by_train_valid' SAVE_BY_TRAIN_VALID = 'save_by_train_valid'
DATA_FORMAT = 'data_format' DATA_FORMAT = 'data_format'
DATA_FORMAT_ARGS = 'data_format_args' DATA_FORMAT_ARGS = 'data_format_args'
STRUCTURE_LIST = 'structure_list' STRUCTURE_LIST = 'structure_list'
LOAD_DATASET = 'load_dataset_path' # not used in v2 LOAD_DATASET = 'load_dataset_path' # not used in v2
LOAD_TRAINSET = 'load_trainset_path' LOAD_TRAINSET = 'load_trainset_path'
LOAD_VALIDSET = 'load_validset_path' LOAD_VALIDSET = 'load_validset_path'
LOAD_TESTSET = 'load_testset_path' LOAD_TESTSET = 'load_testset_path'
FORMAT_OUTPUTS = 'format_outputs_for_ase' FORMAT_OUTPUTS = 'format_outputs_for_ase'
COMPUTE_STATISTICS = 'compute_statistics' COMPUTE_STATISTICS = 'compute_statistics'
DATASET_TYPE = 'dataset_type' DATASET_TYPE = 'dataset_type'
RANDOM_SEED = 'random_seed' RANDOM_SEED = 'random_seed'
RATIO = 'data_divide_ratio' RATIO = 'data_divide_ratio'
USE_TESTSET = 'use_testset' USE_TESTSET = 'use_testset'
EPOCH = 'epoch' EPOCH = 'epoch'
LOSS = 'loss' LOSS = 'loss'
LOSS_PARAM = 'loss_param' LOSS_PARAM = 'loss_param'
OPTIMIZER = 'optimizer' OPTIMIZER = 'optimizer'
OPTIM_PARAM = 'optim_param' OPTIM_PARAM = 'optim_param'
SCHEDULER = 'scheduler' SCHEDULER = 'scheduler'
SCHEDULER_PARAM = 'scheduler_param' SCHEDULER_PARAM = 'scheduler_param'
FORCE_WEIGHT = 'force_loss_weight' FORCE_WEIGHT = 'force_loss_weight'
STRESS_WEIGHT = 'stress_loss_weight' STRESS_WEIGHT = 'stress_loss_weight'
DEVICE = 'device' DEVICE = 'device'
DTYPE = 'dtype' DTYPE = 'dtype'
TRAIN_SHUFFLE = 'train_shuffle' TRAIN_SHUFFLE = 'train_shuffle'
IS_TRAIN_STRESS = 'is_train_stress' IS_TRAIN_STRESS = 'is_train_stress'
CONTINUE = 'continue' CONTINUE = 'continue'
CHECKPOINT = 'checkpoint' CHECKPOINT = 'checkpoint'
RESET_OPTIMIZER = 'reset_optimizer' RESET_OPTIMIZER = 'reset_optimizer'
RESET_SCHEDULER = 'reset_scheduler' RESET_SCHEDULER = 'reset_scheduler'
RESET_EPOCH = 'reset_epoch' RESET_EPOCH = 'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint' USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = (
'use_statistic_values_for_cp_modal_only' 'use_statistic_values_for_cp_modal_only'
) )
CSV_LOG = 'csv_log' CSV_LOG = 'csv_log'
ERROR_RECORD = 'error_record' ERROR_RECORD = 'error_record'
BEST_METRIC = 'best_metric' BEST_METRIC = 'best_metric'
NUM_WORKERS = 'num_workers'  # not working
RANK = 'rank' RANK = 'rank'
LOCAL_RANK = 'local_rank' LOCAL_RANK = 'local_rank'
WORLD_SIZE = 'world_size' WORLD_SIZE = 'world_size'
IS_DDP = 'is_ddp' IS_DDP = 'is_ddp'
DDP_BACKEND = 'ddp_backend' DDP_BACKEND = 'ddp_backend'
PER_EPOCH = 'per_epoch' PER_EPOCH = 'per_epoch'
USE_WEIGHT = 'use_weight' USE_WEIGHT = 'use_weight'
USE_MODALITY = 'use_modality' USE_MODALITY = 'use_modality'
DEFAULT_MODAL = 'default_modal' DEFAULT_MODAL = 'default_modal'
# ==================================================# # ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ # # ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
# ==================================================# # ==================================================#
# ~~ global model configuration ~~ # # ~~ global model configuration ~~ #
# note that these names are directly used for input.yaml for user input # note that these names are directly used for input.yaml for user input
MODEL_TYPE = '_model_type' MODEL_TYPE = '_model_type'
CUTOFF = 'cutoff' CUTOFF = 'cutoff'
CHEMICAL_SPECIES = 'chemical_species' CHEMICAL_SPECIES = 'chemical_species'
MODAL_LIST = 'modal_list' MODAL_LIST = 'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number' CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number'
NUM_SPECIES = '_number_of_species' NUM_SPECIES = '_number_of_species'
NUM_MODALITIES = '_number_of_modalities' NUM_MODALITIES = '_number_of_modalities'
TYPE_MAP = '_type_map' TYPE_MAP = '_type_map'
MODAL_MAP = '_modal_map' MODAL_MAP = '_modal_map'
# ~~ E3 equivariant model build configuration keys ~~ # # ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type # see model_build default_config for type
IRREPS_MANUAL = 'irreps_manual' IRREPS_MANUAL = 'irreps_manual'
NODE_FEATURE_MULTIPLICITY = 'channel' NODE_FEATURE_MULTIPLICITY = 'channel'
RADIAL_BASIS = 'radial_basis' RADIAL_BASIS = 'radial_basis'
BESSEL_BASIS_NUM = 'bessel_basis_num' BESSEL_BASIS_NUM = 'bessel_basis_num'
CUTOFF_FUNCTION = 'cutoff_function' CUTOFF_FUNCTION = 'cutoff_function'
POLY_CUT_P = 'poly_cut_p_value' POLY_CUT_P = 'poly_cut_p_value'
LMAX = 'lmax' LMAX = 'lmax'
LMAX_EDGE = 'lmax_edge' LMAX_EDGE = 'lmax_edge'
LMAX_NODE = 'lmax_node' LMAX_NODE = 'lmax_node'
IS_PARITY = 'is_parity' IS_PARITY = 'is_parity'
CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons' CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons'
NUM_CONVOLUTION = 'num_convolution_layer' NUM_CONVOLUTION = 'num_convolution_layer'
ACTIVATION_SCARLAR = 'act_scalar' ACTIVATION_SCARLAR = 'act_scalar'
ACTIVATION_GATE = 'act_gate' ACTIVATION_GATE = 'act_gate'
ACTIVATION_RADIAL = 'act_radial' ACTIVATION_RADIAL = 'act_radial'
SELF_CONNECTION_TYPE = 'self_connection_type' SELF_CONNECTION_TYPE = 'self_connection_type'
RADIAL_BASIS_NAME = 'radial_basis_name' RADIAL_BASIS_NAME = 'radial_basis_name'
CUTOFF_FUNCTION_NAME = 'cutoff_function_name' CUTOFF_FUNCTION_NAME = 'cutoff_function_name'
USE_BIAS_IN_LINEAR = 'use_bias_in_linear' USE_BIAS_IN_LINEAR = 'use_bias_in_linear'
USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding' USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro' USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro' USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block' USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block'
READOUT_AS_FCN = 'readout_as_fcn' READOUT_AS_FCN = 'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons' READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION = 'readout_fcn_activation' READOUT_FCN_ACTIVATION = 'readout_fcn_activation'
AVG_NUM_NEIGH = 'avg_num_neigh' AVG_NUM_NEIGH = 'avg_num_neigh'
CONV_DENOMINATOR = 'conv_denominator' CONV_DENOMINATOR = 'conv_denominator'
SHIFT = 'shift' SHIFT = 'shift'
SCALE = 'scale' SCALE = 'scale'
USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale' USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift' USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift'
USE_MODAL_WISE_SCALE = 'use_modal_wise_scale' USE_MODAL_WISE_SCALE = 'use_modal_wise_scale'
TRAIN_SHIFT_SCALE = 'train_shift_scale' TRAIN_SHIFT_SCALE = 'train_shift_scale'
TRAIN_DENOMINTAOR = 'train_denominator' TRAIN_DENOMINTAOR = 'train_denominator'
INTERACTION_TYPE = 'interaction_type' INTERACTION_TYPE = 'interaction_type'
TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated
CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' CUEQUIVARIANCE_CONFIG = 'cuequivariance_config'
_NORMALIZE_SPH = '_normalize_sph' _NORMALIZE_SPH = '_normalize_sph'
OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' OPTIMIZE_BY_REDUCE = 'optimize_by_reduce'
from typing import Optional from typing import Optional
import torch import torch
import torch_geometric.data import torch_geometric.data
import sevenn._keys as KEY import sevenn._keys as KEY
import sevenn.util import sevenn.util
class AtomGraphData(torch_geometric.data.Data): class AtomGraphData(torch_geometric.data.Data):
""" """
Args: Args:
x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes, x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes,
atomic_numbers]`. (default: :obj:`None`) atomic_numbers]`. (default: :obj:`None`)
edge_index (LongTensor, optional): Graph connectivity in coordinate edge_index (LongTensor, optional): Graph connectivity in coordinate
format with shape :obj:`[2, num_edges]`. (default: :obj:`None`) format with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge feature matrix with shape edge_attr (Tensor, optional): Edge feature matrix with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
y_energy: scalar # unit of eV (VASP raw) y_energy: scalar # unit of eV (VASP raw)
y_force: [num_nodes, 3] # unit of eV/A (VASP raw) y_force: [num_nodes, 3] # unit of eV/A (VASP raw)
y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw) y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw)
pos (Tensor, optional): Node position matrix with shape pos (Tensor, optional): Node position matrix with shape
:obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes. **kwargs (optional): Additional attributes.
x, y_force, pos should be aligned with each other. x, y_force, pos should be aligned with each other.
""" """
def __init__( def __init__(
self, self,
x: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None,
edge_index: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None,
pos: Optional[torch.Tensor] = None, pos: Optional[torch.Tensor] = None,
edge_attr: Optional[torch.Tensor] = None, edge_attr: Optional[torch.Tensor] = None,
**kwargs **kwargs
): ):
super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos) super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos)
self[KEY.NODE_ATTR] = x # ? self[KEY.NODE_ATTR] = x # ?
for k, v in kwargs.items(): for k, v in kwargs.items():
self[k] = v self[k] = v
def to_numpy_dict(self): def to_numpy_dict(self):
# This is not debugged yet! # This is not debugged yet!
dct = { dct = {
k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v
for k, v in self.items() for k, v in self.items()
} }
return dct return dct
def fit_dimension(self): def fit_dimension(self):
per_atom_keys = [ per_atom_keys = [
KEY.ATOMIC_NUMBERS, KEY.ATOMIC_NUMBERS,
KEY.ATOMIC_ENERGY, KEY.ATOMIC_ENERGY,
KEY.POS, KEY.POS,
KEY.FORCE, KEY.FORCE,
KEY.PRED_FORCE, KEY.PRED_FORCE,
] ]
natoms = self.num_atoms.item() natoms = self.num_atoms.item()
for k, v in self.items(): for k, v in self.items():
if not isinstance(v, torch.Tensor): if not isinstance(v, torch.Tensor):
continue continue
if natoms == 1 and k in per_atom_keys: if natoms == 1 and k in per_atom_keys:
self[k] = v.squeeze().unsqueeze(0) self[k] = v.squeeze().unsqueeze(0)
else: else:
self[k] = v.squeeze() self[k] = v.squeeze()
return self return self
@staticmethod @staticmethod
def from_numpy_dict(dct): def from_numpy_dict(dct):
for k, v in dct.items(): for k, v in dct.items():
if k == KEY.CELL_SHIFT: if k == KEY.CELL_SHIFT:
dct[k] = torch.Tensor(v) # this is special dct[k] = torch.Tensor(v) # this is special
else: else:
dct[k] = sevenn.util.dtype_correct(v) dct[k] = sevenn.util.dtype_correct(v)
return AtomGraphData(**dct) return AtomGraphData(**dct)
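# Editorial construction sketch (not part of the original file): a two-atom graph with
# a single directed edge, with extra fields passed through **kwargs. Shapes follow the
# class docstring; the values are arbitrary.
#
#     data = AtomGraphData(
#         x=torch.tensor([1, 8]),                   # atomic numbers, shape (N,)
#         pos=torch.tensor([[0.0, 0.0, 0.0],
#                           [0.0, 0.0, 0.96]]),     # positions, shape (N, 3)
#         edge_index=torch.tensor([[0], [1]]),      # connectivity, shape (2, N_edge)
#     )
#     data[KEY.NUM_ATOMS] = torch.tensor(2)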
import ctypes import ctypes
import os import os
import pathlib import pathlib
import warnings import warnings
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.jit import torch.jit
import torch.jit._script import torch.jit._script
from ase.calculators.calculator import Calculator, all_changes from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator from ase.calculators.mixing import SumCalculator
from ase.data import chemical_symbols from ase.data import chemical_symbols
import sevenn._keys as KEY import sevenn._keys as KEY
import sevenn.util as util import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData from sevenn.atom_graph_data import AtomGraphData
from sevenn.nn.sequential import AtomGraphSequential from sevenn.nn.sequential import AtomGraphSequential
from sevenn.train.dataload import unlabeled_atoms_to_graph from sevenn.train.dataload import unlabeled_atoms_to_graph
import logging import logging
torch_script_type = torch.jit._script.RecursiveScriptModule torch_script_type = torch.jit._script.RecursiveScriptModule
class SevenNetCalculator(Calculator): class SevenNetCalculator(Calculator):
"""Supporting properties: """Supporting properties:
'free_energy', 'energy', 'forces', 'stress', 'energies' 'free_energy', 'energy', 'forces', 'stress', 'energies'
free_energy equals energy. 'energies' stores atomic energy. free_energy equals energy. 'energies' stores atomic energy.
Multi-GPU acceleration is not supported with ASE calculator. Multi-GPU acceleration is not supported with ASE calculator.
You should use LAMMPS for the acceleration. You should use LAMMPS for the acceleration.
""" """
def __init__( def __init__(
self, self,
model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
file_type: str = 'checkpoint', file_type: str = 'checkpoint',
device: Union[torch.device, str] = 'auto', device: Union[torch.device, str] = 'auto',
modal: Optional[str] = None, modal: Optional[str] = None,
enable_cueq: bool = False, enable_cueq: bool = False,
sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info
**kwargs, **kwargs,
): ):
"""Initialize SevenNetCalculator. """Initialize SevenNetCalculator.
Parameters Parameters
---------- ----------
model: str | Path | AtomGraphSequential, default='7net-0' model: str | Path | AtomGraphSequential, default='7net-0'
Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or
path to the checkpoint, deployed model or the model itself path to the checkpoint, deployed model or the model itself
file_type: str, default='checkpoint' file_type: str, default='checkpoint'
one of 'checkpoint' | 'torchscript' | 'model_instance' one of 'checkpoint' | 'torchscript' | 'model_instance'
device: str | torch.device, default='auto' device: str | torch.device, default='auto'
if not given, use CUDA if available if not given, use CUDA if available
modal: str | None, default=None modal: str | None, default=None
            modal (fidelity) if the given model is a multi-modal model. For 7net-mf-ompa,
            it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24); case insensitive.
enable_cueq: bool, default=False enable_cueq: bool, default=False
if True, use cuEquivariant to accelerate inference. if True, use cuEquivariant to accelerate inference.
sevennet_config: dict | None, default=None sevennet_config: dict | None, default=None
Not used, but can be used to carry meta information of this calculator Not used, but can be used to carry meta information of this calculator
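
        Example (sketch; assumes the pretrained multi-modal weights are available):
            SevenNetCalculator('7net-mf-ompa', modal='mpa', device='cuda')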
""" """
print("&&& Initializing SevenNetCalculator") print("&&& Initializing SevenNetCalculator")
super().__init__(**kwargs) super().__init__(**kwargs)
self.sevennet_config = None self.sevennet_config = None
if isinstance(model, pathlib.PurePath): if isinstance(model, pathlib.PurePath):
model = str(model) model = str(model)
allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] allowed_file_types = ['checkpoint', 'torchscript', 'model_instance']
file_type = file_type.lower() file_type = file_type.lower()
if file_type not in allowed_file_types: if file_type not in allowed_file_types:
raise ValueError(f'file_type not in {allowed_file_types}') raise ValueError(f'file_type not in {allowed_file_types}')
if enable_cueq and file_type in ['model_instance', 'torchscript']: if enable_cueq and file_type in ['model_instance', 'torchscript']:
warnings.warn( warnings.warn(
'file_type should be checkpoint to enable cueq. cueq set to False' 'file_type should be checkpoint to enable cueq. cueq set to False'
) )
enable_cueq = False enable_cueq = False
if isinstance(device, str): # TODO: do we really need this? if isinstance(device, str): # TODO: do we really need this?
if device == 'auto': if device == 'auto':
self.device = torch.device( self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu' 'cuda' if torch.cuda.is_available() else 'cpu'
) )
else: else:
self.device = torch.device(device) self.device = torch.device(device)
else: else:
self.device = device self.device = device
if file_type == 'checkpoint' and isinstance(model, str): if file_type == 'checkpoint' and isinstance(model, str):
cp = util.load_checkpoint(model) cp = util.load_checkpoint(model)
backend = 'e3nn' if not enable_cueq else 'cueq' backend = 'e3nn' if not enable_cueq else 'cueq'
model_loaded = cp.build_model(backend) model_loaded = cp.build_model(backend)
model_loaded.set_is_batch_data(False) model_loaded.set_is_batch_data(False)
self.type_map = cp.config[KEY.TYPE_MAP] self.type_map = cp.config[KEY.TYPE_MAP]
self.cutoff = cp.config[KEY.CUTOFF] self.cutoff = cp.config[KEY.CUTOFF]
self.sevennet_config = cp.config self.sevennet_config = cp.config
elif file_type == 'torchscript' and isinstance(model, str): elif file_type == 'torchscript' and isinstance(model, str):
if modal: if modal:
raise NotImplementedError() raise NotImplementedError()
extra_dict = { extra_dict = {
'chemical_symbols_to_index': b'', 'chemical_symbols_to_index': b'',
'cutoff': b'', 'cutoff': b'',
'num_species': b'', 'num_species': b'',
'model_type': b'', 'model_type': b'',
'version': b'', 'version': b'',
'dtype': b'', 'dtype': b'',
'time': b'', 'time': b'',
} }
model_loaded = torch.jit.load( model_loaded = torch.jit.load(
model, _extra_files=extra_dict, map_location=self.device model, _extra_files=extra_dict, map_location=self.device
) )
chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8') chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8')
sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)} sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)}
self.type_map = { self.type_map = {
sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split())
} }
self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) self.cutoff = float(extra_dict['cutoff'].decode('utf-8'))
elif isinstance(model, AtomGraphSequential): elif isinstance(model, AtomGraphSequential):
if model.type_map is None: if model.type_map is None:
raise ValueError( raise ValueError(
                    'Model must have a type_map to be used with the calculator'
) )
if model.cutoff == 0.0: if model.cutoff == 0.0:
                raise ValueError('Model cutoff seems to be uninitialized')
model.eval_type_map = torch.tensor(True) # ? model.eval_type_map = torch.tensor(True) # ?
model.set_is_batch_data(False) model.set_is_batch_data(False)
model_loaded = model model_loaded = model
self.type_map = model.type_map self.type_map = model.type_map
self.cutoff = model.cutoff self.cutoff = model.cutoff
else: else:
raise ValueError('Unexpected input combinations') raise ValueError('Unexpected input combinations')
if self.sevennet_config is None and sevennet_config is not None: if self.sevennet_config is None and sevennet_config is not None:
self.sevennet_config = sevennet_config self.sevennet_config = sevennet_config
self.model = model_loaded self.model = model_loaded
self.modal = None self.modal = None
if isinstance(self.model, AtomGraphSequential): if isinstance(self.model, AtomGraphSequential):
modal_map = self.model.modal_map modal_map = self.model.modal_map
if modal_map: if modal_map:
modal_ava = list(modal_map.keys()) modal_ava = list(modal_map.keys())
if not modal: if not modal:
raise ValueError(f'modal argument missing (avail: {modal_ava})') raise ValueError(f'modal argument missing (avail: {modal_ava})')
elif modal not in modal_ava: elif modal not in modal_ava:
raise ValueError(f'unknown modal {modal} (not in {modal_ava})') raise ValueError(f'unknown modal {modal} (not in {modal_ava})')
self.modal = modal self.modal = modal
elif not self.model.modal_map and modal: elif not self.model.modal_map and modal:
warnings.warn(f'modal={modal} is ignored as model has no modal_map') warnings.warn(f'modal={modal} is ignored as model has no modal_map')
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
self.implemented_properties = [ self.implemented_properties = [
'free_energy', 'free_energy',
'energy', 'energy',
'forces', 'forces',
'stress', 'stress',
'energies', 'energies',
] ]
def set_atoms(self, atoms): def set_atoms(self, atoms):
# called by ase, when atoms.calc = calc # called by ase, when atoms.calc = calc
zs = tuple(set(atoms.get_atomic_numbers())) zs = tuple(set(atoms.get_atomic_numbers()))
for z in zs: for z in zs:
if z not in self.type_map: if z not in self.type_map:
sp = list(self.type_map.keys()) sp = list(self.type_map.keys())
raise ValueError( raise ValueError(
                    f'Model does not know atomic number {z} (knows: {sp})'
) )
def output_to_results(self, output): def output_to_results(self, output):
energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item() energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
num_atoms = output['num_atoms'].item() num_atoms = output['num_atoms'].item()
atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten() atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten()
forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :] forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :]
stress = np.array( stress = np.array(
(-output[KEY.PRED_STRESS]) (-output[KEY.PRED_STRESS])
.detach() .detach()
.cpu() .cpu()
.numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation
) )
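        # Note (added): the reindex [0, 1, 2, 4, 5, 3] rearranges the model's raw
        # stress components into ASE's Voigt order (xx, yy, zz, yz, xz, xy); the
        # leading minus matches the sign convention expected by ASE.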
# Store results # Store results
return { return {
'free_energy': energy, 'free_energy': energy,
'energy': energy, 'energy': energy,
'energies': atomic_energies, 'energies': atomic_energies,
'forces': forces, 'forces': forces,
'stress': stress, 'stress': stress,
'num_edges': output[KEY.EDGE_IDX].shape[1], 'num_edges': output[KEY.EDGE_IDX].shape[1],
} }
def calculate(self, atoms=None, properties=None, system_changes=all_changes): def calculate(self, atoms=None, properties=None, system_changes=all_changes):
# call parent class to set necessary atom attributes # call parent class to set necessary atom attributes
Calculator.calculate(self, atoms, properties, system_changes) Calculator.calculate(self, atoms, properties, system_changes)
if atoms is None: if atoms is None:
raise ValueError('No atoms to evaluate') raise ValueError('No atoms to evaluate')
data = AtomGraphData.from_numpy_dict( data = AtomGraphData.from_numpy_dict(
unlabeled_atoms_to_graph(atoms, self.cutoff) unlabeled_atoms_to_graph(atoms, self.cutoff)
) )
if self.modal: if self.modal:
data[KEY.DATA_MODALITY] = self.modal data[KEY.DATA_MODALITY] = self.modal
data.to(self.device) # type: ignore data.to(self.device) # type: ignore
if isinstance(self.model, torch_script_type): if isinstance(self.model, torch_script_type):
data[KEY.NODE_FEATURE] = torch.tensor( data[KEY.NODE_FEATURE] = torch.tensor(
[self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
data[KEY.POS].requires_grad_(True) # backward compatibility data[KEY.POS].requires_grad_(True) # backward compatibility
data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility
data = data.to_dict() data = data.to_dict()
del data['data_info'] del data['data_info']
logging.debug(f"data: {data}") logging.debug(f"data: {data}")
# logging.debug(f"data[pos]: {data['pos']}") # logging.debug(f"data[pos]: {data['pos']}")
# logging.debug(f"data[x]: {data['x']}") # logging.debug(f"data[x]: {data['x']}")
logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}") logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}")
logging.debug(f"data[cell_volume]: {data['cell_volume']}") logging.debug(f"data[cell_volume]: {data['cell_volume']}")
output = self.model(data) output = self.model(data)
# logging.info(f"input: {data}") # logging.info(f"input: {data}")
# logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}")
# logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}")
# logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}")
self.results = self.output_to_results(output) self.results = self.output_to_results(output)
# logging.debug(f"results['energy'] = {self.results['energy']}") # logging.debug(f"results['energy'] = {self.results['energy']}")
# logging.debug(f"results['forces'] = {self.results['forces']}") # logging.debug(f"results['forces'] = {self.results['forces']}")
# logging.debug(f"results['stress'] = {self.results['stress']}") # logging.debug(f"results['stress'] = {self.results['stress']}")
def predict_one(self, atoms): def predict_one(self, atoms):
if atoms is None: if atoms is None:
raise ValueError('No atoms to evaluate') raise ValueError('No atoms to evaluate')
data = AtomGraphData.from_numpy_dict( data = AtomGraphData.from_numpy_dict(
unlabeled_atoms_to_graph(atoms, self.cutoff) unlabeled_atoms_to_graph(atoms, self.cutoff)
) )
if self.modal: if self.modal:
data[KEY.DATA_MODALITY] = self.modal data[KEY.DATA_MODALITY] = self.modal
data.to(self.device) # type: ignore data.to(self.device) # type: ignore
if isinstance(self.model, torch_script_type): if isinstance(self.model, torch_script_type):
data[KEY.NODE_FEATURE] = torch.tensor( data[KEY.NODE_FEATURE] = torch.tensor(
[self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
data[KEY.POS].requires_grad_(True) # backward compatibility data[KEY.POS].requires_grad_(True) # backward compatibility
data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility
data = data.to_dict() data = data.to_dict()
del data['data_info'] del data['data_info']
return self.model(data) return self.model(data)
def predict(self, atoms_list, properties=None): def predict(self, atoms_list, properties=None):
# if len(atoms_list) == 1: # if len(atoms_list) == 1:
# output = self.predict_one(atoms_list[0]) # output = self.predict_one(atoms_list[0])
# predictions = {} # predictions = {}
# predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0) # predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0)
# predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0) # predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0)
# voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0) # voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0)
# stress_list = [] # stress_list = []
# for i in range(voigt.shape[0]): # for i in range(voigt.shape[0]):
# stress_list.append(self._stress2tensor(voigt[i,:])) # stress_list.append(self._stress2tensor(voigt[i,:]))
# predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3) # predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3)
# return predictions # return predictions
if not atoms_list: if not atoms_list:
raise ValueError("Empty atoms_list provided") raise ValueError("Empty atoms_list provided")
if not isinstance(atoms_list, list): if not isinstance(atoms_list, list):
atoms_list = [atoms_list] atoms_list = [atoms_list]
# Convert atoms to graph data # Convert atoms to graph data
graph_list = [] graph_list = []
for atoms in atoms_list: for atoms in atoms_list:
data = AtomGraphData.from_numpy_dict( data = AtomGraphData.from_numpy_dict(
unlabeled_atoms_to_graph(atoms, self.cutoff) unlabeled_atoms_to_graph(atoms, self.cutoff)
) )
if self.modal: if self.modal:
data[KEY.DATA_MODALITY] = self.modal data[KEY.DATA_MODALITY] = self.modal
if isinstance(self.model, torch_script_type): if isinstance(self.model, torch_script_type):
data[KEY.NODE_FEATURE] = torch.tensor( data[KEY.NODE_FEATURE] = torch.tensor(
[self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
data[KEY.POS].requires_grad_(True) # backward compatibility data[KEY.POS].requires_grad_(True) # backward compatibility
data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility
graph_list.append(data) graph_list.append(data)
# Process graphs based on model type # Process graphs based on model type
# was_batch_mode = True # was_batch_mode = True
if isinstance(self.model, AtomGraphSequential): if isinstance(self.model, AtomGraphSequential):
# was_batch_mode = self.model.is_batch_data # was_batch_mode = self.model.is_batch_data
self.model.set_is_batch_data(True) self.model.set_is_batch_data(True)
self.model.eval() self.model.eval()
        # Collate the per-structure graphs into a single batch
from torch_geometric.loader.dataloader import Collater from torch_geometric.loader.dataloader import Collater
batched_data = Collater(graph_list)(graph_list) batched_data = Collater(graph_list)(graph_list)
batched_data = batched_data.to(self.device) batched_data = batched_data.to(self.device)
logging.debug(f"batched_data: {batched_data}") logging.debug(f"batched_data: {batched_data}")
# logging.debug(f"batched_data[pos]: {batched_data['pos']}") # logging.debug(f"batched_data[pos]: {batched_data['pos']}")
# logging.debug(f"batched_data[x]: {batched_data['x']}") # logging.debug(f"batched_data[x]: {batched_data['x']}")
logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}") logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}")
logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}") logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}")
# Run model on batched data # Run model on batched data
if isinstance(self.model, torch_script_type): if isinstance(self.model, torch_script_type):
batched_dict = batched_data.to_dict() batched_dict = batched_data.to_dict()
if 'data_info' in batched_dict: if 'data_info' in batched_dict:
del batched_dict['data_info'] del batched_dict['data_info']
output = self.model(batched_dict) output = self.model(batched_dict)
else: else:
output = self.model(batched_data) output = self.model(batched_data)
# Convert to list of individual outputs using util.to_atom_graph_list # Convert to list of individual outputs using util.to_atom_graph_list
# logging.info(f"input: {batched_data}") # logging.info(f"input: {batched_data}")
# logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}")
# logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}")
# logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}")
predictions = {} predictions = {}
predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach() predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach()
predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach() predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach()
voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach() voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach()
stress_list = [] stress_list = []
for i in range(voigt.shape[0]): for i in range(voigt.shape[0]):
stress_list.append(self._stress2tensor(voigt[i,:])) stress_list.append(self._stress2tensor(voigt[i,:]))
predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3).detach() predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3).detach()
# logging.debug(f"predictions['energy'] = {predictions['energy']}") # logging.debug(f"predictions['energy'] = {predictions['energy']}")
# logging.debug(f"predictions['forces'] = {predictions['forces']}") # logging.debug(f"predictions['forces'] = {predictions['forces']}")
# logging.debug(f"predictions['stress'] = {predictions['stress']}") # logging.debug(f"predictions['stress'] = {predictions['stress']}")
return predictions return predictions
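    # Minimal usage sketch (hedged; `structures` is a hypothetical list of ase.Atoms):
    #   preds = calc.predict(structures)
    #   preds['energy']  -> (n_structures,) float64 tensor
    #   preds['forces']  -> (total_atoms, 3) float64 tensor, concatenated over structures
    #   preds['stress']  -> (n_structures, 3, 3) float64 tensor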
def _stress2tensor(self, stress): def _stress2tensor(self, stress):
tensor = torch.tensor( tensor = torch.tensor(
[ [
[stress[0], stress[5], stress[4]], [stress[0], stress[5], stress[4]],
[stress[5], stress[1], stress[3]], [stress[5], stress[1], stress[3]],
[stress[4], stress[3], stress[2]], [stress[4], stress[3], stress[2]],
], ],
device=self.device device=self.device
) )
return tensor return tensor
class SevenNetD3Calculator(SumCalculator): class SevenNetD3Calculator(SumCalculator):
def __init__( def __init__(
self, self,
model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
file_type: str = 'checkpoint', file_type: str = 'checkpoint',
device: Union[torch.device, str] = 'auto', device: Union[torch.device, str] = 'auto',
sevennet_config: Optional[Any] = None, # hold meta information sevennet_config: Optional[Any] = None, # hold meta information
damping_type: str = 'damp_bj', damping_type: str = 'damp_bj',
functional_name: str = 'pbe', functional_name: str = 'pbe',
vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au
cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au
batch_size=10, batch_size=10,
**kwargs, **kwargs,
): ):
"""Initialize SevenNetD3Calculator. CUDA required. """Initialize SevenNetD3Calculator. CUDA required.
Parameters Parameters
---------- ----------
model: str | Path | AtomGraphSequential model: str | Path | AtomGraphSequential
Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or
path to the checkpoint, deployed model or the model itself path to the checkpoint, deployed model or the model itself
file_type: str, default='checkpoint' file_type: str, default='checkpoint'
one of 'checkpoint' | 'torchscript' | 'model_instance' one of 'checkpoint' | 'torchscript' | 'model_instance'
device: str | torch.device, default='auto' device: str | torch.device, default='auto'
if not given, use CUDA if available if not given, use CUDA if available
modal: str | None, default=None modal: str | None, default=None
modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa,
it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24)
enable_cueq: bool, default=False enable_cueq: bool, default=False
if True, use cuEquivariant to accelerate inference. if True, use cuEquivariant to accelerate inference.
damping_type: str, default='damp_bj' damping_type: str, default='damp_bj'
Damping type of D3, one of 'damp_bj' | 'damp_zero' Damping type of D3, one of 'damp_bj' | 'damp_zero'
functional_name: str, default='pbe' functional_name: str, default='pbe'
Target functional name of D3 parameters. Target functional name of D3 parameters.
vdw_cutoff: float, default=9000 vdw_cutoff: float, default=9000
vdw cutoff of D3 calculator in au vdw cutoff of D3 calculator in au
        cn_cutoff: float, default=1600
            cn cutoff of D3 calculator in au
        batch_size: int, default=10
            number of D3Calculator instances pre-allocated for predict();
            should be at least the number of structures passed in one call
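
        Example (sketch): SevenNetD3Calculator('7net-0', functional_name='pbe', damping_type='damp_bj')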
""" """
self.d3_calc = D3Calculator( self.d3_calc = D3Calculator(
damping_type=damping_type, damping_type=damping_type,
functional_name=functional_name, functional_name=functional_name,
vdw_cutoff=vdw_cutoff, vdw_cutoff=vdw_cutoff,
cn_cutoff=cn_cutoff, cn_cutoff=cn_cutoff,
**kwargs, **kwargs,
) )
self.sevennet_calc = SevenNetCalculator( self.sevennet_calc = SevenNetCalculator(
model=model, model=model,
file_type=file_type, file_type=file_type,
device=device, device=device,
sevennet_config=sevennet_config, sevennet_config=sevennet_config,
**kwargs, **kwargs,
) )
super().__init__([self.sevennet_calc, self.d3_calc]) super().__init__([self.sevennet_calc, self.d3_calc])
        self.device = self.sevennet_calc.device  # use the resolved device ('auto' is not a valid torch.device)
self.d3_calcs = [] self.d3_calcs = []
for _ in range(batch_size): for _ in range(batch_size):
self.d3_calcs.append( self.d3_calcs.append(
D3Calculator( D3Calculator(
damping_type=damping_type, damping_type=damping_type,
functional_name=functional_name, functional_name=functional_name,
vdw_cutoff=vdw_cutoff, vdw_cutoff=vdw_cutoff,
cn_cutoff=cn_cutoff, cn_cutoff=cn_cutoff,
**kwargs, **kwargs,
) )
) )
def predict(self, atoms_list): def predict(self, atoms_list):
"""Predict the energy and forces for a list of atoms. """Predict the energy and forces for a list of atoms.
""" """
# Call the predict method of the first calculator (SevenNetCalculator) # Call the predict method of the first calculator (SevenNetCalculator)
predictions = self.sevennet_calc.predict(atoms_list) predictions = self.sevennet_calc.predict(atoms_list)
energy_list = [] energy_list = []
forces_list = [] forces_list = []
stress_list = [] stress_list = []
predictions3d = {} predictions3d = {}
for i, atoms in enumerate(atoms_list): for i, atoms in enumerate(atoms_list):
prediction = self.d3_calcs[i].predict_one(atoms) prediction = self.d3_calcs[i].predict_one(atoms)
energy_list.append(torch.tensor(prediction['energy'])) energy_list.append(torch.tensor(prediction['energy']))
forces_list.append(torch.from_numpy(prediction['forces']).to(self.device)) forces_list.append(torch.from_numpy(prediction['forces']).to(self.device))
stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress']))) stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress'])))
# Convert lists to tensors # Convert lists to tensors
predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device) predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device)
predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3) predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3)
predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3) predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3)
predictions['energy'] += predictions3d['energy'].detach() predictions['energy'] += predictions3d['energy'].detach()
predictions['forces'] += predictions3d['forces'].detach() predictions['forces'] += predictions3d['forces'].detach()
predictions['stress'] += predictions3d['stress'].detach() predictions['stress'] += predictions3d['stress'].detach()
return predictions return predictions
def _stress2tensor(self, stress): def _stress2tensor(self, stress):
tensor = torch.tensor( tensor = torch.tensor(
[ [
# [stress[0], stress[3], stress[4]], # [stress[0], stress[3], stress[4]],
# [stress[3], stress[1], stress[5]], # [stress[3], stress[1], stress[5]],
# [stress[4], stress[5], stress[2]], # [stress[4], stress[5], stress[2]],
[stress[0], stress[5], stress[4]], [stress[0], stress[5], stress[4]],
[stress[5], stress[1], stress[3]], [stress[5], stress[1], stress[3]],
[stress[4], stress[3], stress[2]], [stress[4], stress[3], stress[2]],
], ],
device=self.device device=self.device
) )
return tensor return tensor
def _load(name: str) -> ctypes.CDLL: def _load(name: str) -> ctypes.CDLL:
from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load
# Load the library from the candidate locations # Load the library from the candidate locations
package_dir = os.path.dirname(os.path.abspath(__file__)) package_dir = os.path.dirname(os.path.abspath(__file__))
try: try:
return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}')) return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}'))
except OSError: except OSError:
pass pass
cache_dir = _get_build_directory(name, verbose=False) cache_dir = _get_build_directory(name, verbose=False)
try: try:
return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}')) return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}'))
except OSError: except OSError:
pass pass
# Compile the library if it is not found # Compile the library if it is not found
if os.access(package_dir, os.W_OK): if os.access(package_dir, os.W_OK):
compile_dir = package_dir compile_dir = package_dir
else: else:
print('Warning: package directory is not writable. Using cache directory.') print('Warning: package directory is not writable. Using cache directory.')
compile_dir = cache_dir compile_dir = cache_dir
if 'TORCH_CUDA_ARCH_LIST' not in os.environ: if 'TORCH_CUDA_ARCH_LIST' not in os.environ:
print('Warning: TORCH_CUDA_ARCH_LIST is not set.') print('Warning: TORCH_CUDA_ARCH_LIST is not set.')
        print('Warning: Using default CUDA architectures: 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, 9.0')
os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0' os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0'
load( load(
name=name, name=name,
sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')], sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')],
extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'], extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'],
build_directory=compile_dir, build_directory=compile_dir,
verbose=True, verbose=True,
is_python_module=False, is_python_module=False,
) )
return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}')) return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}'))
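# Note (added): torch.utils.cpp_extension honors the TORCH_CUDA_ARCH_LIST environment
# variable, so exporting e.g. TORCH_CUDA_ARCH_LIST="8.6" before the first import
# restricts the JIT build above to a single architecture and shortens compile time.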
class PairD3(ctypes.Structure): class PairD3(ctypes.Structure):
pass # Opaque structure; only used as a pointer pass # Opaque structure; only used as a pointer
class D3Calculator(Calculator): class D3Calculator(Calculator):
"""ASE calculator for accelerated D3 van der Waals (vdW) correction. """ASE calculator for accelerated D3 van der Waals (vdW) correction.
Example: Example:
from ase.calculators.mixing import SumCalculator from ase.calculators.mixing import SumCalculator
calc_1 = SevenNetCalculator() calc_1 = SevenNetCalculator()
calc_2 = D3Calculator() calc_2 = D3Calculator()
return SumCalculator([calc_1, calc_2]) return SumCalculator([calc_1, calc_2])
This calculator interfaces with the `libpaird3.so` library, This calculator interfaces with the `libpaird3.so` library,
which is compiled by nvcc during the package installation. which is compiled by nvcc during the package installation.
If you encounter any errors, please verify If you encounter any errors, please verify
the installation process and the compilation options in `setup.py`. the installation process and the compilation options in `setup.py`.
Note: Multi-GPU parallel MD is not supported in this mode. Note: Multi-GPU parallel MD is not supported in this mode.
Note: Cffi could be used, but it was avoided to reduce dependencies. Note: Cffi could be used, but it was avoided to reduce dependencies.
""" """
# Here, free_energy = energy # Here, free_energy = energy
implemented_properties = ['free_energy', 'energy', 'forces', 'stress'] implemented_properties = ['free_energy', 'energy', 'forces', 'stress']
def __init__( def __init__(
self, self,
damping_type: str = 'damp_bj', # damp_bj, damp_zero damping_type: str = 'damp_bj', # damp_bj, damp_zero
functional_name: str = 'pbe', # check the source code functional_name: str = 'pbe', # check the source code
vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au
cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise NotImplementedError('CPU + D3 is not implemented yet') raise NotImplementedError('CPU + D3 is not implemented yet')
self.rthr = vdw_cutoff self.rthr = vdw_cutoff
self.cnthr = cn_cutoff self.cnthr = cn_cutoff
self.damp_name = damping_type.lower() self.damp_name = damping_type.lower()
self.func_name = functional_name.lower() self.func_name = functional_name.lower()
if self.damp_name not in ['damp_bj', 'damp_zero']: if self.damp_name not in ['damp_bj', 'damp_zero']:
raise ValueError('Error: Invalid damping type.') raise ValueError('Error: Invalid damping type.')
self._lib = _load('pair_d3') self._lib = _load('pair_d3')
self._lib.pair_init.restype = ctypes.POINTER(PairD3) self._lib.pair_init.restype = ctypes.POINTER(PairD3)
self.pair = self._lib.pair_init() self.pair = self._lib.pair_init()
self._lib.pair_set_atom.argtypes = [ self._lib.pair_set_atom.argtypes = [
ctypes.POINTER(PairD3), # PairD3* pair ctypes.POINTER(PairD3), # PairD3* pair
ctypes.c_int, # int natoms ctypes.c_int, # int natoms
ctypes.c_int, # int ntypes ctypes.c_int, # int ntypes
ctypes.POINTER(ctypes.c_int), # int* types ctypes.POINTER(ctypes.c_int), # int* types
ctypes.POINTER(ctypes.c_double), # double* x ctypes.POINTER(ctypes.c_double), # double* x
] ]
self._lib.pair_set_atom.restype = None self._lib.pair_set_atom.restype = None
self._lib.pair_set_domain.argtypes = [ self._lib.pair_set_domain.argtypes = [
ctypes.POINTER(PairD3), # PairD3* pair ctypes.POINTER(PairD3), # PairD3* pair
ctypes.c_int, # int xperiodic ctypes.c_int, # int xperiodic
ctypes.c_int, # int yperiodic ctypes.c_int, # int yperiodic
ctypes.c_int, # int zperiodic ctypes.c_int, # int zperiodic
ctypes.POINTER(ctypes.c_double), # double* boxlo ctypes.POINTER(ctypes.c_double), # double* boxlo
ctypes.POINTER(ctypes.c_double), # double* boxhi ctypes.POINTER(ctypes.c_double), # double* boxhi
ctypes.c_double, # double xy ctypes.c_double, # double xy
ctypes.c_double, # double xz ctypes.c_double, # double xz
ctypes.c_double, # double yz ctypes.c_double, # double yz
] ]
self._lib.pair_set_domain.restype = None self._lib.pair_set_domain.restype = None
self._lib.pair_run_settings.argtypes = [ self._lib.pair_run_settings.argtypes = [
ctypes.POINTER(PairD3), # PairD3* pair ctypes.POINTER(PairD3), # PairD3* pair
ctypes.c_double, # double rthr ctypes.c_double, # double rthr
ctypes.c_double, # double cnthr ctypes.c_double, # double cnthr
ctypes.c_char_p, # const char* damp_name ctypes.c_char_p, # const char* damp_name
ctypes.c_char_p, # const char* func_name ctypes.c_char_p, # const char* func_name
] ]
self._lib.pair_run_settings.restype = None self._lib.pair_run_settings.restype = None
self._lib.pair_run_coeff.argtypes = [ self._lib.pair_run_coeff.argtypes = [
ctypes.POINTER(PairD3), # PairD3* pair ctypes.POINTER(PairD3), # PairD3* pair
ctypes.POINTER(ctypes.c_int), # int* atomic_numbers ctypes.POINTER(ctypes.c_int), # int* atomic_numbers
] ]
self._lib.pair_run_coeff.restype = None self._lib.pair_run_coeff.restype = None
self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)] self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)]
self._lib.pair_run_compute.restype = None self._lib.pair_run_compute.restype = None
self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)] self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)]
self._lib.pair_get_energy.restype = ctypes.c_double self._lib.pair_get_energy.restype = ctypes.c_double
self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)] self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)]
self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double) self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double)
self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)] self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)]
self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6) self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6)
self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)] self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)]
self._lib.pair_fin.restype = None self._lib.pair_fin.restype = None
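        # Call sequence expected by the wrapped library (as used in calculate()/predict_one()):
        #   pair_set_atom -> pair_set_domain -> pair_run_settings -> pair_run_coeff
        #   -> pair_run_compute, then pair_get_energy / pair_get_force / pair_get_stress,
        #   and finally pair_fin() on deletion.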
def _idx_to_numbers(self, Z_of_atoms): def _idx_to_numbers(self, Z_of_atoms):
unique_numbers = list(dict.fromkeys(Z_of_atoms)) unique_numbers = list(dict.fromkeys(Z_of_atoms))
return unique_numbers return unique_numbers
def _idx_to_types(self, Z_of_atoms): def _idx_to_types(self, Z_of_atoms):
unique_numbers = list(dict.fromkeys(Z_of_atoms)) unique_numbers = list(dict.fromkeys(Z_of_atoms))
mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)} mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)}
atom_types = [mapping[num] for num in Z_of_atoms] atom_types = [mapping[num] for num in Z_of_atoms]
return atom_types return atom_types
def _convert_domain_ase2lammps(self, cell): def _convert_domain_ase2lammps(self, cell):
qtrans, ltrans = np.linalg.qr(cell.T, mode='complete') qtrans, ltrans = np.linalg.qr(cell.T, mode='complete')
lammps_cell = ltrans.T lammps_cell = ltrans.T
signs = np.sign(np.diag(lammps_cell)) signs = np.sign(np.diag(lammps_cell))
lammps_cell = lammps_cell * signs lammps_cell = lammps_cell * signs
qtrans = qtrans * signs qtrans = qtrans * signs
lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)] lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)]
rotator = qtrans.T rotator = qtrans.T
return lammps_cell, rotator return lammps_cell, rotator
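        # The returned 6-vector is ordered (lx, ly, lz, xy, xz, yz), i.e. the LAMMPS
        # triclinic box parameters consumed by pair_set_domain(); `rotator` maps ASE
        # coordinates into that frame (positions use `pos @ rotator.T`, and forces
        # are rotated back with `forces @ rotator`).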
def _stress2tensor(self, stress): def _stress2tensor(self, stress):
tensor = np.array( tensor = np.array(
[ [
[stress[0], stress[3], stress[4]], [stress[0], stress[3], stress[4]],
[stress[3], stress[1], stress[5]], [stress[3], stress[1], stress[5]],
[stress[4], stress[5], stress[2]], [stress[4], stress[5], stress[2]],
] ]
) )
return tensor return tensor
def _tensor2stress(self, tensor): def _tensor2stress(self, tensor):
stress = -np.array( stress = -np.array(
[ [
tensor[0, 0], tensor[0, 0],
tensor[1, 1], tensor[1, 1],
tensor[2, 2], tensor[2, 2],
tensor[1, 2], tensor[1, 2],
tensor[0, 2], tensor[0, 2],
tensor[0, 1], tensor[0, 1],
] ]
) )
return stress return stress
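    # Ordering note (added): _stress2tensor() expands the library's 6-component virial,
    # ordered (xx, yy, zz, xy, xz, yz), into a 3x3 tensor, while _tensor2stress()
    # returns ASE's Voigt order (xx, yy, zz, yz, xz, xy) with a sign flip; the caller
    # divides by the cell volume to obtain the stress reported to ASE.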
def calculate(self, atoms=None, properties=None, system_changes=all_changes): def calculate(self, atoms=None, properties=None, system_changes=all_changes):
Calculator.calculate(self, atoms, properties, system_changes) Calculator.calculate(self, atoms, properties, system_changes)
if atoms is None: if atoms is None:
raise ValueError('No atoms to evaluate') raise ValueError('No atoms to evaluate')
if atoms.get_cell().sum() == 0: if atoms.get_cell().sum() == 0:
print( print(
'Warning: D3Calculator requires a cell.\n' 'Warning: D3Calculator requires a cell.\n'
                'Warning: A sufficiently large orthogonal cell is generated.'
) )
positions = atoms.get_positions() positions = atoms.get_positions()
min_pos = positions.min(axis=0) min_pos = positions.min(axis=0)
max_pos = positions.max(axis=0) max_pos = positions.max(axis=0)
max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin
cell = np.eye(3) * cell_lengths cell = np.eye(3) * cell_lengths
atoms.set_cell(cell) atoms.set_cell(cell)
atoms.set_pbc([True, True, True]) # for minus positions atoms.set_pbc([True, True, True]) # for minus positions
cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())
Z_of_atoms = atoms.get_atomic_numbers() Z_of_atoms = atoms.get_atomic_numbers()
natoms = len(atoms) natoms = len(atoms)
ntypes = len(set(Z_of_atoms)) ntypes = len(set(Z_of_atoms))
types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
positions = atoms.get_positions() @ rotator.T positions = atoms.get_positions() @ rotator.T
x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
xy = cell[3] xy = cell[3]
xz = cell[4] xz = cell[4]
yz = cell[5] yz = cell[5]
xperiodic, yperiodic, zperiodic = atoms.get_pbc() xperiodic, yperiodic, zperiodic = atoms.get_pbc()
lib = self._lib lib = self._lib
assert lib is not None assert lib is not None
lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
xperiodic = xperiodic.astype(int) xperiodic = xperiodic.astype(int)
yperiodic = yperiodic.astype(int) yperiodic = yperiodic.astype(int)
zperiodic = zperiodic.astype(int) zperiodic = zperiodic.astype(int)
lib.pair_set_domain( lib.pair_set_domain(
self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz
) )
lib.pair_run_settings( lib.pair_run_settings(
self.pair, self.pair,
self.rthr, self.rthr,
self.cnthr, self.cnthr,
self.damp_name.encode('utf-8'), self.damp_name.encode('utf-8'),
self.func_name.encode('utf-8'), self.func_name.encode('utf-8'),
) )
lib.pair_run_coeff(self.pair, atomic_numbers) lib.pair_run_coeff(self.pair, atomic_numbers)
lib.pair_run_compute(self.pair) lib.pair_run_compute(self.pair)
result_E = lib.pair_get_energy(self.pair) result_E = lib.pair_get_energy(self.pair)
result_F_ptr = lib.pair_get_force(self.pair) result_F_ptr = lib.pair_get_force(self.pair)
result_F_size = natoms * 3 result_F_size = natoms * 3
result_F = np.ctypeslib.as_array( result_F = np.ctypeslib.as_array(
result_F_ptr, shape=(result_F_size,) result_F_ptr, shape=(result_F_size,)
).reshape((natoms, 3)) ).reshape((natoms, 3))
result_F = np.array(result_F) result_F = np.array(result_F)
result_F = result_F @ rotator result_F = result_F @ rotator
result_S = lib.pair_get_stress(self.pair) result_S = lib.pair_get_stress(self.pair)
result_S = np.array(result_S.contents) result_S = np.array(result_S.contents)
result_S = ( result_S = (
self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
/ atoms.get_volume() / atoms.get_volume()
) )
self.results = { self.results = {
'free_energy': result_E, 'free_energy': result_E,
'energy': result_E, 'energy': result_E,
'forces': result_F, 'forces': result_F,
'stress': result_S, 'stress': result_S,
} }
def predict_one(self, atoms): def predict_one(self, atoms):
atoms = atoms.copy() atoms = atoms.copy()
if atoms is None: if atoms is None:
raise ValueError('No atoms to evaluate') raise ValueError('No atoms to evaluate')
if atoms.get_cell().sum() == 0: if atoms.get_cell().sum() == 0:
print( print(
'Warning: D3Calculator requires a cell.\n' 'Warning: D3Calculator requires a cell.\n'
                'Warning: A sufficiently large orthogonal cell is generated.'
) )
positions = atoms.get_positions() positions = atoms.get_positions()
min_pos = positions.min(axis=0) min_pos = positions.min(axis=0)
max_pos = positions.max(axis=0) max_pos = positions.max(axis=0)
max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726
cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin
cell = np.eye(3) * cell_lengths cell = np.eye(3) * cell_lengths
atoms.set_cell(cell) atoms.set_cell(cell)
atoms.set_pbc([True, True, True]) # for minus positions atoms.set_pbc([True, True, True]) # for minus positions
cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell())
Z_of_atoms = atoms.get_atomic_numbers() Z_of_atoms = atoms.get_atomic_numbers()
natoms = len(atoms) natoms = len(atoms)
ntypes = len(set(Z_of_atoms)) ntypes = len(set(Z_of_atoms))
types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms))
positions = atoms.get_positions() @ rotator.T positions = atoms.get_positions() @ rotator.T
x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten())
atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms))
boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0)
boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2])
xy = cell[3] xy = cell[3]
xz = cell[4] xz = cell[4]
yz = cell[5] yz = cell[5]
xperiodic, yperiodic, zperiodic = atoms.get_pbc() xperiodic, yperiodic, zperiodic = atoms.get_pbc()
lib = self._lib lib = self._lib
assert lib is not None assert lib is not None
lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat)
xperiodic = xperiodic.astype(int) xperiodic = xperiodic.astype(int)
yperiodic = yperiodic.astype(int) yperiodic = yperiodic.astype(int)
zperiodic = zperiodic.astype(int) zperiodic = zperiodic.astype(int)
lib.pair_set_domain( lib.pair_set_domain(
self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz
) )
lib.pair_run_settings( lib.pair_run_settings(
self.pair, self.pair,
self.rthr, self.rthr,
self.cnthr, self.cnthr,
self.damp_name.encode('utf-8'), self.damp_name.encode('utf-8'),
self.func_name.encode('utf-8'), self.func_name.encode('utf-8'),
) )
lib.pair_run_coeff(self.pair, atomic_numbers) lib.pair_run_coeff(self.pair, atomic_numbers)
lib.pair_run_compute(self.pair) lib.pair_run_compute(self.pair)
result_E = lib.pair_get_energy(self.pair) result_E = lib.pair_get_energy(self.pair)
result_F_ptr = lib.pair_get_force(self.pair) result_F_ptr = lib.pair_get_force(self.pair)
result_F_size = natoms * 3 result_F_size = natoms * 3
result_F = np.ctypeslib.as_array( result_F = np.ctypeslib.as_array(
result_F_ptr, shape=(result_F_size,) result_F_ptr, shape=(result_F_size,)
).reshape((natoms, 3)) ).reshape((natoms, 3))
result_F = np.array(result_F) result_F = np.array(result_F)
result_F = result_F @ rotator result_F = result_F @ rotator
result_S = lib.pair_get_stress(self.pair) result_S = lib.pair_get_stress(self.pair)
result_S = np.array(result_S.contents) result_S = np.array(result_S.contents)
result_S = ( result_S = (
self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator)
/ atoms.get_volume() / atoms.get_volume()
) )
prediction = { prediction = {
'free_energy': float(result_E), 'free_energy': float(result_E),
'energy': float(result_E), 'energy': float(result_E),
'forces': result_F.copy(), 'forces': result_F.copy(),
'stress': result_S.copy(), 'stress': result_S.copy(),
} }
return prediction return prediction
def __del__(self): def __del__(self):
if self._lib is not None: if self._lib is not None:
self._lib.pair_fin(self.pair) self._lib.pair_fin(self.pair)
self._lib = None self._lib = None
self.pair = None self.pair = None
import os import os
import pathlib import pathlib
import uuid import uuid
import warnings import warnings
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import pandas as pd import pandas as pd
from packaging.version import Version from packaging.version import Version
from torch import Tensor from torch import Tensor
from torch import load as torch_load from torch import load as torch_load
import sevenn import sevenn
import sevenn._const as consts import sevenn._const as consts
import sevenn._keys as KEY import sevenn._keys as KEY
import sevenn.scripts.backward_compatibility as compat import sevenn.scripts.backward_compatibility as compat
from sevenn import model_build from sevenn import model_build
from sevenn.nn.scale import get_resolved_shift_scale from sevenn.nn.scale import get_resolved_shift_scale
from sevenn.nn.sequential import AtomGraphSequential from sevenn.nn.sequential import AtomGraphSequential
def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
import numpy as np import numpy as np
def acl(a, b, rtol=rtol, atol=atol): def acl(a, b, rtol=rtol, atol=atol):
return np.allclose(a, b, rtol=rtol, atol=atol) return np.allclose(a, b, rtol=rtol, atol=atol)
assert len(atoms1) == len(atoms2) assert len(atoms1) == len(atoms2)
assert acl(atoms1.get_cell(), atoms2.get_cell()) assert acl(atoms1.get_cell(), atoms2.get_cell())
assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy())
assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10)
assert acl( assert acl(
atoms1.get_stress(voigt=False), atoms1.get_stress(voigt=False),
atoms2.get_stress(voigt=False), atoms2.get_stress(voigt=False),
rtol * 10, rtol * 10,
atol * 10, atol * 10,
) )
# assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())
def copy_state_dict(state_dict) -> dict: def copy_state_dict(state_dict) -> dict:
if isinstance(state_dict, dict): if isinstance(state_dict, dict):
return {key: copy_state_dict(value) for key, value in state_dict.items()} return {key: copy_state_dict(value) for key, value in state_dict.items()}
elif isinstance(state_dict, list): elif isinstance(state_dict, list):
return [copy_state_dict(item) for item in state_dict] # type: ignore return [copy_state_dict(item) for item in state_dict] # type: ignore
elif isinstance(state_dict, Tensor): elif isinstance(state_dict, Tensor):
return state_dict.clone() # type: ignore return state_dict.clone() # type: ignore
else: else:
# For non-tensor values (e.g., scalars, None), return as-is # For non-tensor values (e.g., scalars, None), return as-is
return state_dict return state_dict
def _config_cp_routine(config): def _config_cp_routine(config):
    cp_ver = Version(config.get('version', '0.0.0'))  # missing version: treat as very old
this_ver = Version(sevenn.__version__) this_ver = Version(sevenn.__version__)
if cp_ver > this_ver: if cp_ver > this_ver:
        warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source '
                      f'({this_ver}). This may cause unexpected behaviors')
defaults = {**consts.model_defaults(config)} defaults = {**consts.model_defaults(config)}
config = compat.patch_old_config(config) # type: ignore config = compat.patch_old_config(config) # type: ignore
scaler = model_build.init_shift_scale(config) scaler = model_build.init_shift_scale(config)
shift, scale = get_resolved_shift_scale( shift, scale = get_resolved_shift_scale(
scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None) scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None)
) )
config['shift'] = shift config['shift'] = shift
config['scale'] = scale config['scale'] = scale
for k, v in defaults.items(): for k, v in defaults.items():
if k in config: if k in config:
continue continue
if os.getenv('SEVENN_DEBUG', False): if os.getenv('SEVENN_DEBUG', False):
warnings.warn(f'{k} not in config, use default value {v}', UserWarning) warnings.warn(f'{k} not in config, use default value {v}', UserWarning)
config[k] = v config[k] = v
for k, v in config.items(): for k, v in config.items():
if isinstance(v, Tensor): if isinstance(v, Tensor):
config[k] = v.cpu() config[k] = v.cpu()
return config return config
def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq): def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq):
""" """
manually check keys and assert if something unexpected happens manually check keys and assert if something unexpected happens
""" """
n_layer = src_config['num_convolution_layer'] n_layer = src_config['num_convolution_layer']
linear_module_names = [ linear_module_names = [
'onehot_to_feature_x', 'onehot_to_feature_x',
'reduce_input_to_hidden', 'reduce_input_to_hidden',
'reduce_hidden_to_energy', 'reduce_hidden_to_energy',
] ]
convolution_module_names = [] convolution_module_names = []
fc_tensor_product_module_names = [] fc_tensor_product_module_names = []
for i in range(n_layer): for i in range(n_layer):
linear_module_names.append(f'{i}_self_interaction_1') linear_module_names.append(f'{i}_self_interaction_1')
linear_module_names.append(f'{i}_self_interaction_2') linear_module_names.append(f'{i}_self_interaction_2')
if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear': if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear':
linear_module_names.append(f'{i}_self_connection_intro') linear_module_names.append(f'{i}_self_connection_intro')
elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip': elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip':
fc_tensor_product_module_names.append(f'{i}_self_connection_intro') fc_tensor_product_module_names.append(f'{i}_self_connection_intro')
convolution_module_names.append(f'{i}_convolution') convolution_module_names.append(f'{i}_convolution')
    # Rule: these keys can be safely ignored before loading the state dict,
    # except for linear.bias, which must be rejected before this function is
    # called. The others are constants, not parameters.
cue_only_linear_followers = ['linear.f.tp.f_fx.module.c'] cue_only_linear_followers = ['linear.f.tp.f_fx.module.c']
e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask'] e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask']
ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers
cue_only_conv_followers = [ cue_only_conv_followers = [
'convolution.f.tp.f_fx.module.c', 'convolution.f.tp.f_fx.module.c',
'convolution.f.tp.module.module.f.module.module._f.data', 'convolution.f.tp.module.module.f.module.module._f.data',
] ]
e3nn_only_conv_followers = [ e3nn_only_conv_followers = [
'convolution._compiled_main_left_right._w3j', 'convolution._compiled_main_left_right._w3j',
'convolution.weight', 'convolution.weight',
'convolution.output_mask', 'convolution.output_mask',
] ]
ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers
cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c'] cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c']
e3nn_only_fc_followers = [ e3nn_only_fc_followers = [
'fc_tensor_product.output_mask', 'fc_tensor_product.output_mask',
] ]
ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers
updated_keys = [] updated_keys = []
for k, v in stct_src.items(): for k, v in stct_src.items():
module_name = k.split('.')[0] module_name = k.split('.')[0]
flag = False flag = False
if module_name in linear_module_names: if module_name in linear_module_names:
for ignore in ignores_in_linear: for ignore in ignores_in_linear:
if '.'.join([module_name, ignore]) in k: if '.'.join([module_name, ignore]) in k:
flag = True flag = True
break break
if not flag and k == '.'.join([module_name, 'linear.weight']): if not flag and k == '.'.join([module_name, 'linear.weight']):
updated_keys.append(k) updated_keys.append(k)
stct_dst[k] = v.clone().reshape(stct_dst[k].shape) stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
flag = True flag = True
assert flag, f'Unexpected key from linear: {k}' assert flag, f'Unexpected key from linear: {k}'
elif module_name in convolution_module_names: elif module_name in convolution_module_names:
for ignore in ignores_in_conv: for ignore in ignores_in_conv:
if '.'.join([module_name, ignore]) in k: if '.'.join([module_name, ignore]) in k:
flag = True flag = True
break break
if not flag and ( if not flag and (
k.startswith(f'{module_name}.weight_nn') k.startswith(f'{module_name}.weight_nn')
or k == '.'.join([module_name, 'denominator']) or k == '.'.join([module_name, 'denominator'])
): ):
updated_keys.append(k) updated_keys.append(k)
stct_dst[k] = v.clone().reshape(stct_dst[k].shape) stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
flag = True flag = True
assert flag, f'Unexpected key from convolution: {k}'
elif module_name in fc_tensor_product_module_names: elif module_name in fc_tensor_product_module_names:
for ignore in ignores_in_fc: for ignore in ignores_in_fc:
if '.'.join([module_name, ignore]) in k: if '.'.join([module_name, ignore]) in k:
flag = True flag = True
break break
if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']): if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']):
updated_keys.append(k) updated_keys.append(k)
stct_dst[k] = v.clone().reshape(stct_dst[k].shape) stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
flag = True flag = True
assert flag, f'Unexpected key from fc tensor product: {k}' assert flag, f'Unexpected key from fc tensor product: {k}'
else: else:
# assert k in stct_dst # assert k in stct_dst
updated_keys.append(k) updated_keys.append(k)
stct_dst[k] = v.clone().reshape(stct_dst[k].shape) stct_dst[k] = v.clone().reshape(stct_dst[k].shape)
return stct_dst return stct_dst
class SevenNetCheckpoint: class SevenNetCheckpoint:
""" """
Tool box for checkpoint processed from SevenNet. Tool box for checkpoint processed from SevenNet.
""" """
def __init__(self, checkpoint_path: Union[pathlib.Path, str]): def __init__(self, checkpoint_path: Union[pathlib.Path, str]):
self._checkpoint_path = os.path.abspath(checkpoint_path) self._checkpoint_path = os.path.abspath(checkpoint_path)
self._config = None self._config = None
self._epoch = None self._epoch = None
self._model_state_dict = None self._model_state_dict = None
self._optimizer_state_dict = None self._optimizer_state_dict = None
self._scheduler_state_dict = None self._scheduler_state_dict = None
self._hash = None self._hash = None
self._time = None self._time = None
self._loaded = False self._loaded = False
def __repr__(self) -> str: def __repr__(self) -> str:
cfg = self.config # just alias cfg = self.config # just alias
if len(cfg) == 0: if len(cfg) == 0:
return '' return ''
dct = { dct = {
'Sevennet version': cfg.get('version', 'Not found'), 'Sevennet version': cfg.get('version', 'Not found'),
'When': self.time, 'When': self.time,
'Hash': self.hash, 'Hash': self.hash,
'Cutoff': cfg.get('cutoff'), 'Cutoff': cfg.get('cutoff'),
'Channel': cfg.get('channel'), 'Channel': cfg.get('channel'),
'Lmax': cfg.get('lmax'), 'Lmax': cfg.get('lmax'),
'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3', 'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3',
'Interaction layers': cfg.get('num_convolution_layer'), 'Interaction layers': cfg.get('num_convolution_layer'),
'Self connection type': cfg.get('self_connection_type', 'nequip'), 'Self connection type': cfg.get('self_connection_type', 'nequip'),
'Last epoch': self.epoch, 'Last epoch': self.epoch,
'Elements': len(cfg.get('chemical_species', [])), 'Elements': len(cfg.get('chemical_species', [])),
} }
if cfg.get('use_modality', False): if cfg.get('use_modality', False):
dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys())) dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys()))
df = pd.DataFrame.from_dict([dct]).T # type: ignore df = pd.DataFrame.from_dict([dct]).T # type: ignore
df.columns = [''] df.columns = ['']
return df.to_string() return df.to_string()
@property @property
def checkpoint_path(self) -> str: def checkpoint_path(self) -> str:
return str(self._checkpoint_path) return str(self._checkpoint_path)
@property @property
def config(self) -> Dict[str, Any]: def config(self) -> Dict[str, Any]:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._config, dict) assert isinstance(self._config, dict)
return deepcopy(self._config) return deepcopy(self._config)
@property @property
def model_state_dict(self) -> Dict[str, Any]: def model_state_dict(self) -> Dict[str, Any]:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._model_state_dict, dict) assert isinstance(self._model_state_dict, dict)
return copy_state_dict(self._model_state_dict) return copy_state_dict(self._model_state_dict)
@property @property
def optimizer_state_dict(self) -> Dict[str, Any]: def optimizer_state_dict(self) -> Dict[str, Any]:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._optimizer_state_dict, dict) assert isinstance(self._optimizer_state_dict, dict)
return copy_state_dict(self._optimizer_state_dict) return copy_state_dict(self._optimizer_state_dict)
@property @property
def scheduler_state_dict(self) -> Dict[str, Any]: def scheduler_state_dict(self) -> Dict[str, Any]:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._scheduler_state_dict, dict) assert isinstance(self._scheduler_state_dict, dict)
return copy_state_dict(self._scheduler_state_dict) return copy_state_dict(self._scheduler_state_dict)
@property @property
def epoch(self) -> Optional[int]: def epoch(self) -> Optional[int]:
if not self._loaded: if not self._loaded:
self._load() self._load()
return self._epoch return self._epoch
@property @property
def time(self) -> str: def time(self) -> str:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._time, str) assert isinstance(self._time, str)
return self._time return self._time
@property @property
def hash(self) -> str: def hash(self) -> str:
if not self._loaded: if not self._loaded:
self._load() self._load()
assert isinstance(self._hash, str) assert isinstance(self._hash, str)
return self._hash return self._hash
def _load(self) -> None: def _load(self) -> None:
assert not self._loaded assert not self._loaded
cp_path = self.checkpoint_path # just alias cp_path = self.checkpoint_path # just alias
cp = torch_load(cp_path, weights_only=False, map_location='cpu') cp = torch_load(cp_path, weights_only=False, map_location='cpu')
self._config_original = cp.get('config', {}) self._config_original = cp.get('config', {})
self._model_state_dict = cp.get('model_state_dict', {}) self._model_state_dict = cp.get('model_state_dict', {})
self._optimizer_state_dict = cp.get('optimizer_state_dict', {}) self._optimizer_state_dict = cp.get('optimizer_state_dict', {})
self._scheduler_state_dict = cp.get('scheduler_state_dict', {}) self._scheduler_state_dict = cp.get('scheduler_state_dict', {})
self._epoch = cp.get('epoch', None) self._epoch = cp.get('epoch', None)
self._time = cp.get('time', 'Not found') self._time = cp.get('time', 'Not found')
self._hash = cp.get('hash', 'Not found') self._hash = cp.get('hash', 'Not found')
if len(self._config_original) == 0: if len(self._config_original) == 0:
warnings.warn(f'config is not found in {cp_path}')
self._config = {} self._config = {}
else: else:
self._config = _config_cp_routine(self._config_original) self._config = _config_cp_routine(self._config_original)
if len(self._model_state_dict) == 0: if len(self._model_state_dict) == 0:
warnings.warn(f'model_state_dict is not found in {cp_path}')
self._loaded = True self._loaded = True
def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential: def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential:
from .model_build import build_E3_equivariant_model from .model_build import build_E3_equivariant_model
use_cue = not backend or backend.lower() in ['cue', 'cueq'] use_cue = not backend or backend.lower() in ['cue', 'cueq']
try: try:
cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use'] cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use']
except KeyError: except KeyError:
cp_using_cue = False cp_using_cue = False
if (not backend) or (use_cue == cp_using_cue): if (not backend) or (use_cue == cp_using_cue):
# backend not given, or checkpoint backend is same as requested # backend not given, or checkpoint backend is same as requested
model = build_E3_equivariant_model(self.config) model = build_E3_equivariant_model(self.config)
state_dict = compat.patch_state_dict_if_old( state_dict = compat.patch_state_dict_if_old(
self.model_state_dict, self.config, model self.model_state_dict, self.config, model
) )
else: else:
cfg_new = self.config cfg_new = self.config
cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue} cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue}
model = build_E3_equivariant_model(cfg_new) model = build_E3_equivariant_model(cfg_new)
stct_src = compat.patch_state_dict_if_old( stct_src = compat.patch_state_dict_if_old(
self.model_state_dict, self.config, model self.model_state_dict, self.config, model
) )
state_dict = _convert_e3nn_and_cueq( state_dict = _convert_e3nn_and_cueq(
stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue
) )
missing, not_used = model.load_state_dict(state_dict, strict=False) missing, not_used = model.load_state_dict(state_dict, strict=False)
if len(not_used) > 0: if len(not_used) > 0:
warnings.warn(f'Some keys are not used: {not_used}', UserWarning) warnings.warn(f'Some keys are not used: {not_used}', UserWarning)
assert len(missing) == 0, f'Missing keys: {missing}' assert len(missing) == 0, f'Missing keys: {missing}'
return model return model
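# Hedged example of choosing a backend (assumes cuEquivariance is installed
# when 'cueq' is requested; any other string selects the e3nn path here):
#
#     model_e3nn = cp.build_model('e3nn')
#     model_cueq = cp.build_model('cueq')
#
# When the requested backend differs from the one stored in the checkpoint,
# the weights are translated via _convert_e3nn_and_cueq before loading.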
def yaml_dict(self, mode: str) -> dict: def yaml_dict(self, mode: str) -> dict:
""" """
Return dict for input.yaml from checkpoint config Return dict for input.yaml from checkpoint config
Dataset paths and statistic values are removed intentionally Dataset paths and statistic values are removed intentionally
""" """
if mode not in ['reproduce', 'continue', 'continue_modal']: if mode not in ['reproduce', 'continue', 'continue_modal']:
raise ValueError(f'Unknown mode: {mode}') raise ValueError(f'Unknown mode: {mode}')
ignore = [ ignore = [
'when', 'when',
KEY.DDP_BACKEND, KEY.DDP_BACKEND,
KEY.LOCAL_RANK, KEY.LOCAL_RANK,
KEY.IS_DDP, KEY.IS_DDP,
KEY.DEVICE, KEY.DEVICE,
KEY.MODEL_TYPE, KEY.MODEL_TYPE,
KEY.SHIFT, KEY.SHIFT,
KEY.SCALE, KEY.SCALE,
KEY.CONV_DENOMINATOR, KEY.CONV_DENOMINATOR,
KEY.SAVE_DATASET, KEY.SAVE_DATASET,
KEY.SAVE_BY_LABEL, KEY.SAVE_BY_LABEL,
KEY.SAVE_BY_TRAIN_VALID, KEY.SAVE_BY_TRAIN_VALID,
KEY.CONTINUE, KEY.CONTINUE,
KEY.LOAD_DATASET, # old KEY.LOAD_DATASET, # old
] ]
cfg = self.config cfg = self.config
len_atoms = len(cfg[KEY.TYPE_MAP]) len_atoms = len(cfg[KEY.TYPE_MAP])
world_size = cfg.pop(KEY.WORLD_SIZE, 1) world_size = cfg.pop(KEY.WORLD_SIZE, 1)
cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size
cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**' cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**'
major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3] major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3]
if int(major) == 0 and int(minor) <= 9: if int(major) == 0 and int(minor) <= 9:
warnings.warn('checkpoint version is too old; the generated yaml may be wrong')
ret = {'model': {}, 'train': {}, 'data': {}} ret = {'model': {}, 'train': {}, 'data': {}}
for k, v in cfg.items(): for k, v in cfg.items():
if k.startswith('_') or k in ignore or k.endswith('set_path'): if k.startswith('_') or k in ignore or k.endswith('set_path'):
continue continue
if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG: if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG:
ret['model'][k] = v ret['model'][k] = v
elif k in consts.DEFAULT_TRAINING_CONFIG: elif k in consts.DEFAULT_TRAINING_CONFIG:
ret['train'][k] = v ret['train'][k] = v
elif k in consts.DEFAULT_DATA_CONFIG: elif k in consts.DEFAULT_DATA_CONFIG:
ret['data'][k] = v ret['data'][k] = v
ret['model'][KEY.CHEMICAL_SPECIES] = ( ret['model'][KEY.CHEMICAL_SPECIES] = (
'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto' 'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto'
) )
ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**' ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**'
ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**' ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**'
# TODO # TODO
ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**' ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**'
ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**' ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**'
if mode.startswith('continue'): if mode.startswith('continue'):
ret['train'].update( ret['train'].update(
{KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}} {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}}
) )
modal_names = None modal_names = None
if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False): if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False):
ret['train'][KEY.USE_MODALITY] = True ret['train'][KEY.USE_MODALITY] = True
# suggest defaults # suggest defaults
ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False
ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True
ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True
ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True
ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True
ret['data'][KEY.USE_MODAL_WISE_SCALE] = False ret['data'][KEY.USE_MODAL_WISE_SCALE] = False
modal_names = ['my_modal1', 'my_modal2'] modal_names = ['my_modal1', 'my_modal2']
elif cfg.get(KEY.USE_MODALITY, False): elif cfg.get(KEY.USE_MODALITY, False):
modal_names = list(cfg[KEY.MODAL_MAP].keys()) modal_names = list(cfg[KEY.MODAL_MAP].keys())
if modal_names: if modal_names:
ret['data'][KEY.LOAD_TRAINSET] = [ ret['data'][KEY.LOAD_TRAINSET] = [
{'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]} {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]}
for mm in modal_names for mm in modal_names
] ]
return ret return ret
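# Sketch of writing the result to disk (the yaml import and the filename are
# assumptions; non-primitive values, if any, may need conversion first):
#
#     import yaml
#     with open('continue_input.yaml', 'w') as f:
#         yaml.safe_dump(cp.yaml_dict('continue'), f, sort_keys=False)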
def append_modal( def append_modal(
self, self,
dst_config, dst_config,
original_modal_name: str = 'origin', original_modal_name: str = 'origin',
working_dir: str = os.getcwd(), working_dir: str = os.getcwd(),
): ):
""" """ """ """
import sevenn.train.modal_dataset as modal_dataset import sevenn.train.modal_dataset as modal_dataset
from sevenn.model_build import init_shift_scale from sevenn.model_build import init_shift_scale
from sevenn.scripts.convert_model_modality import _append_modal_weight from sevenn.scripts.convert_model_modality import _append_modal_weight
src_config = self.config src_config = self.config
src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False) src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False)
# inherit element-related keys first
chem_keys = [ chem_keys = [
KEY.TYPE_MAP, KEY.TYPE_MAP,
KEY.NUM_SPECIES, KEY.NUM_SPECIES,
KEY.CHEMICAL_SPECIES, KEY.CHEMICAL_SPECIES,
KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
] ]
dst_config.update({k: src_config[k] for k in chem_keys}) dst_config.update({k: src_config[k] for k in chem_keys})
if dst_config[KEY.USE_MODAL_WISE_SHIFT] and ( if dst_config[KEY.USE_MODAL_WISE_SHIFT] and (
KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str) KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str)
): ):
raise ValueError('To use modal wise shift, keyword shift is required') raise ValueError('To use modal wise shift, keyword shift is required')
if dst_config[KEY.USE_MODAL_WISE_SCALE] and ( if dst_config[KEY.USE_MODAL_WISE_SCALE] and (
KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str) KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str)
): ):
raise ValueError('To use modal wise scale, keyword scale is required') raise ValueError('To use modal wise scale, keyword scale is required')
if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]: if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]:
dst_config[KEY.SHIFT] = src_config[KEY.SHIFT] dst_config[KEY.SHIFT] = src_config[KEY.SHIFT]
if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]: if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]:
dst_config[KEY.SCALE] = src_config[KEY.SCALE] dst_config[KEY.SCALE] = src_config[KEY.SCALE]
# get statistics of the datasets given in the yaml
# (dst_config is updated in place)
_ = modal_dataset.from_config(dst_config, working_dir=working_dir) _ = modal_dataset.from_config(dst_config, working_dir=working_dir)
dst_modal_map = dst_config[KEY.MODAL_MAP] dst_modal_map = dst_config[KEY.MODAL_MAP]
found_modal_names = list(dst_modal_map.keys()) found_modal_names = list(dst_modal_map.keys())
if len(found_modal_names) == 0: if len(found_modal_names) == 0:
raise ValueError('No modality is found from config') raise ValueError('No modality is found from config')
# Check the difference between the given modals and the new modal map
orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0}) orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0})
assert isinstance(orig_modal_map, dict) assert isinstance(orig_modal_map, dict)
new_modal_map = orig_modal_map.copy() new_modal_map = orig_modal_map.copy()
for modal_name in found_modal_names: for modal_name in found_modal_names:
if modal_name in orig_modal_map: # duplicate, skipping if modal_name in orig_modal_map: # duplicate, skipping
continue continue
new_modal_map[modal_name] = len(new_modal_map) # assign new new_modal_map[modal_name] = len(new_modal_map) # assign new
print(f'New modals: {list(new_modal_map.keys())}') print(f'New modals: {list(new_modal_map.keys())}')
if src_has_no_modal: if src_has_no_modal:
append_num = len(new_modal_map) append_num = len(new_modal_map)
else: else:
append_num = len(new_modal_map) - len(orig_modal_map) append_num = len(new_modal_map) - len(orig_modal_map)
if append_num == 0: if append_num == 0:
raise ValueError('Nothing to append from checkpoint') raise ValueError('Nothing to append from checkpoint')
dst_config[KEY.NUM_MODALITIES] = len(new_modal_map) dst_config[KEY.NUM_MODALITIES] = len(new_modal_map)
dst_config[KEY.MODAL_MAP] = new_modal_map dst_config[KEY.MODAL_MAP] = new_modal_map
# update dst_config's shift scales based on src_config # update dst_config's shift scales based on src_config
for ss_key, use_mw in ( for ss_key, use_mw in (
(KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]), (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]),
(KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]), (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]),
): ):
if not use_mw:  # not using modal-wise shift/scale, just assign
assert not isinstance(dst_config[ss_key], dict) assert not isinstance(dst_config[ss_key], dict)
dst_config[ss_key] = src_config[ss_key] dst_config[ss_key] = src_config[ss_key]
elif src_has_no_modal: elif src_has_no_modal:
assert isinstance(dst_config[ss_key], dict) assert isinstance(dst_config[ss_key], dict)
# modal-wise shift/scale: update the dict, keyed by original_modal_name
dst_config[ss_key].update({original_modal_name: src_config[ss_key]}) dst_config[ss_key].update({original_modal_name: src_config[ss_key]})
else: else:
assert isinstance(dst_config[ss_key], dict) assert isinstance(dst_config[ss_key], dict)
# modal-wise shift/scale: update the dict
dst_config[ss_key].update(src_config[ss_key]) dst_config[ss_key].update(src_config[ss_key])
scaler = init_shift_scale(dst_config) scaler = init_shift_scale(dst_config)
# finally, prepare updated continuable state dict using above # finally, prepare updated continuable state dict using above
orig_model = self.build_model() orig_model = self.build_model()
orig_state_dict = orig_model.state_dict() orig_state_dict = orig_model.state_dict()
new_state_dict = copy_state_dict(orig_state_dict) new_state_dict = copy_state_dict(orig_state_dict)
for stct_key in orig_state_dict: for stct_key in orig_state_dict:
sp = stct_key.split('.') sp = stct_key.split('.')
k, follower = sp[0], '.'.join(sp[1:]) k, follower = sp[0], '.'.join(sp[1:])
if k == 'rescale_atomic_energy' and follower == 'shift': if k == 'rescale_atomic_energy' and follower == 'shift':
new_state_dict[stct_key] = scaler.shift.clone() new_state_dict[stct_key] = scaler.shift.clone()
elif k == 'rescale_atomic_energy' and follower == 'scale': elif k == 'rescale_atomic_energy' and follower == 'scale':
new_state_dict[stct_key] = scaler.scale.clone() new_state_dict[stct_key] = scaler.scale.clone()
elif follower == 'linear.weight' and ( # append linear layer elif follower == 'linear.weight' and ( # append linear layer
( (
dst_config[KEY.USE_MODAL_NODE_EMBEDDING] dst_config[KEY.USE_MODAL_NODE_EMBEDDING]
and k.endswith('onehot_to_feature_x') and k.endswith('onehot_to_feature_x')
) )
or ( or (
dst_config[KEY.USE_MODAL_SELF_INTER_INTRO] dst_config[KEY.USE_MODAL_SELF_INTER_INTRO]
and k.endswith('self_interaction_1') and k.endswith('self_interaction_1')
) )
or ( or (
dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO] dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO]
and k.endswith('self_interaction_2') and k.endswith('self_interaction_2')
) )
or ( or (
dst_config[KEY.USE_MODAL_OUTPUT_BLOCK] dst_config[KEY.USE_MODAL_OUTPUT_BLOCK]
and k == 'reduce_input_to_hidden' and k == 'reduce_input_to_hidden'
) )
): ):
orig_linear = getattr(orig_model._modules[k], 'linear') orig_linear = getattr(orig_model._modules[k], 'linear')
# assert normalization element # assert normalization element
new_state_dict[stct_key] = _append_modal_weight( new_state_dict[stct_key] = _append_modal_weight(
orig_state_dict, orig_state_dict,
k, k,
orig_linear.irreps_in, orig_linear.irreps_in,
orig_linear.irreps_out, orig_linear.irreps_out,
append_num, append_num,
) )
dst_config['version'] = sevenn.__version__ dst_config['version'] = sevenn.__version__
return new_state_dict return new_state_dict
def get_checkpoint_dict(self) -> dict: def get_checkpoint_dict(self) -> dict:
""" """
Return a duplicate of this checkpoint's contents with a new hash and time.
Convenient for creating a variant of the checkpoint.
""" """
return { return {
'config': self.config, 'config': self.config,
'epoch': self.epoch, 'epoch': self.epoch,
'model_state_dict': self.model_state_dict, 'model_state_dict': self.model_state_dict,
'optimizer_state_dict': self.optimizer_state_dict, 'optimizer_state_dict': self.optimizer_state_dict,
'scheduler_state_dict': self.scheduler_state_dict, 'scheduler_state_dict': self.scheduler_state_dict,
'time': datetime.now().strftime('%Y-%m-%d %H:%M'), 'time': datetime.now().strftime('%Y-%m-%d %H:%M'),
'hash': uuid.uuid4().hex, 'hash': uuid.uuid4().hex,
} }
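# Example of persisting a variant checkpoint (a sketch; assumes torch is
# importable here and the output filename is illustrative):
#
#     import torch
#     new_cp = cp.get_checkpoint_dict()
#     torch.save(new_cp, 'checkpoint_variant.pth')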
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import sevenn._keys as KEY import sevenn._keys as KEY
from sevenn.train.loss import LossDefinition from sevenn.train.loss import LossDefinition
from .atom_graph_data import AtomGraphData from .atom_graph_data import AtomGraphData
from .train.optim import loss_dict from .train.optim import loss_dict
_ERROR_TYPES = { _ERROR_TYPES = {
'TotalEnergy': { 'TotalEnergy': {
'name': 'Energy', 'name': 'Energy',
'ref_key': KEY.ENERGY, 'ref_key': KEY.ENERGY,
'pred_key': KEY.PRED_TOTAL_ENERGY, 'pred_key': KEY.PRED_TOTAL_ENERGY,
'unit': 'eV', 'unit': 'eV',
'vdim': 1, 'vdim': 1,
}, },
'Energy': { # by default per-atom for energy 'Energy': { # by default per-atom for energy
'name': 'Energy', 'name': 'Energy',
'ref_key': KEY.ENERGY, 'ref_key': KEY.ENERGY,
'pred_key': KEY.PRED_TOTAL_ENERGY, 'pred_key': KEY.PRED_TOTAL_ENERGY,
'unit': 'eV/atom', 'unit': 'eV/atom',
'per_atom': True, 'per_atom': True,
'vdim': 1, 'vdim': 1,
}, },
'Force': { 'Force': {
'name': 'Force', 'name': 'Force',
'ref_key': KEY.FORCE, 'ref_key': KEY.FORCE,
'pred_key': KEY.PRED_FORCE, 'pred_key': KEY.PRED_FORCE,
'unit': 'eV/Å', 'unit': 'eV/Å',
'vdim': 3, 'vdim': 3,
}, },
'Stress': { 'Stress': {
'name': 'Stress', 'name': 'Stress',
'ref_key': KEY.STRESS, 'ref_key': KEY.STRESS,
'pred_key': KEY.PRED_STRESS, 'pred_key': KEY.PRED_STRESS,
'unit': 'kbar', 'unit': 'kbar',
'coeff': 1602.1766208, 'coeff': 1602.1766208,
'vdim': 6, 'vdim': 6,
}, },
'Stress_GPa': { 'Stress_GPa': {
'name': 'Stress', 'name': 'Stress',
'ref_key': KEY.STRESS, 'ref_key': KEY.STRESS,
'pred_key': KEY.PRED_STRESS, 'pred_key': KEY.PRED_STRESS,
'unit': 'GPa', 'unit': 'GPa',
'coeff': 160.21766208, 'coeff': 160.21766208,
'vdim': 6, 'vdim': 6,
}, },
'TotalLoss': { 'TotalLoss': {
'name': 'TotalLoss', 'name': 'TotalLoss',
'unit': None, 'unit': None,
}, },
} }
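# Note on the stress entries above: predictions are assumed to be in eV/Å³
# internally, and 1 eV/Å³ = 160.21766208 GPa = 1602.1766208 kbar, which is
# where the 'coeff' values come from.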
def get_err_type(name: str) -> Dict[str, Any]: def get_err_type(name: str) -> Dict[str, Any]:
return deepcopy(_ERROR_TYPES[name]) return deepcopy(_ERROR_TYPES[name])
def _get_loss_function_from_name(loss_functions, name): def _get_loss_function_from_name(loss_functions, name):
for loss_def, w in loss_functions: for loss_def, w in loss_functions:
if loss_def.name.lower() == name.lower(): if loss_def.name.lower() == name.lower():
return loss_def, w return loss_def, w
return None, None return None, None
class AverageNumber: class AverageNumber:
def __init__(self): def __init__(self):
self._sum = 0.0 self._sum = 0.0
self._count = 0 self._count = 0
def update(self, values: torch.Tensor): def update(self, values: torch.Tensor):
self._sum += values.sum().item() self._sum += values.sum().item()
self._count += values.numel() self._count += values.numel()
def _ddp_reduce(self, device): def _ddp_reduce(self, device):
_sum = torch.tensor(self._sum, device=device) _sum = torch.tensor(self._sum, device=device)
_count = torch.tensor(self._count, device=device) _count = torch.tensor(self._count, device=device)
dist.all_reduce(_sum, op=dist.ReduceOp.SUM) dist.all_reduce(_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(_count, op=dist.ReduceOp.SUM) dist.all_reduce(_count, op=dist.ReduceOp.SUM)
self._sum = _sum.item() self._sum = _sum.item()
self._count = _count.item() self._count = _count.item()
def get(self): def get(self):
if self._count == 0: if self._count == 0:
return torch.nan return torch.nan
return self._sum / self._count return self._sum / self._count
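# Minimal illustration of AverageNumber (values are made up):
#
#     avg = AverageNumber()
#     avg.update(torch.tensor([1.0, 2.0, 3.0]))
#     avg.update(torch.tensor([4.0]))
#     avg.get()   # -> 2.5, i.e. (1 + 2 + 3 + 4) / 4
#
# _ddp_reduce sums both the running sum and the count across ranks, so get()
# returns the same global average on every process.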
class ErrorMetric: class ErrorMetric:
""" """
Base class for error metrics. We always average the error by the number of
structures, and the class is designed to collect errors mid-iteration (via AverageNumber).
""" """
def __init__( def __init__(
self, self,
name: str, name: str,
ref_key: str, ref_key: str,
pred_key: str, pred_key: str,
coeff: float = 1.0, coeff: float = 1.0,
unit: Optional[str] = None, unit: Optional[str] = None,
per_atom: bool = False, per_atom: bool = False,
ignore_unlabeled: bool = True, ignore_unlabeled: bool = True,
**kwargs, **kwargs,
): ):
self.name = name self.name = name
self.unit = unit self.unit = unit
self.coeff = coeff self.coeff = coeff
self.ref_key = ref_key self.ref_key = ref_key
self.pred_key = pred_key self.pred_key = pred_key
self.per_atom = per_atom self.per_atom = per_atom
self.ignore_unlabeled = ignore_unlabeled self.ignore_unlabeled = ignore_unlabeled
self.value = AverageNumber() self.value = AverageNumber()
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
raise NotImplementedError raise NotImplementedError
def _retrieve(self, output: AtomGraphData): def _retrieve(self, output: AtomGraphData):
y_ref = output[self.ref_key] * self.coeff y_ref = output[self.ref_key] * self.coeff
y_pred = output[self.pred_key] * self.coeff y_pred = output[self.pred_key] * self.coeff
if self.per_atom: if self.per_atom:
assert y_ref.dim() == 1 and y_pred.dim() == 1 assert y_ref.dim() == 1 and y_pred.dim() == 1
natoms = output[KEY.NUM_ATOMS] natoms = output[KEY.NUM_ATOMS]
y_ref = y_ref / natoms y_ref = y_ref / natoms
y_pred = y_pred / natoms y_pred = y_pred / natoms
if self.ignore_unlabeled: if self.ignore_unlabeled:
unlabelled_idx = torch.isnan(y_ref) unlabelled_idx = torch.isnan(y_ref)
y_ref = y_ref[~unlabelled_idx] y_ref = y_ref[~unlabelled_idx]
y_pred = y_pred[~unlabelled_idx] y_pred = y_pred[~unlabelled_idx]
return y_ref, y_pred return y_ref, y_pred
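# _retrieve scales both tensors by coeff, optionally divides by the number of
# atoms (per_atom), and drops entries whose reference is NaN when
# ignore_unlabeled is set, so partially labeled batches are handled.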
def ddp_reduce(self, device): def ddp_reduce(self, device):
self.value._ddp_reduce(device) self.value._ddp_reduce(device)
def reset(self): def reset(self):
self.value = AverageNumber() self.value = AverageNumber()
def get(self): def get(self):
return self.value.get() return self.value.get()
def key_str(self, with_unit=True): def key_str(self, with_unit=True):
if self.unit is None or not with_unit: if self.unit is None or not with_unit:
return self.name return self.name
else: else:
return f'{self.name} ({self.unit})' return f'{self.name} ({self.unit})'
def __str__(self): def __str__(self):
return f'{self.key_str()}: {self.value.get():.6f}' return f'{self.key_str()}: {self.value.get():.6f}'
class RMSError(ErrorMetric): class RMSError(ErrorMetric):
""" """
Vector RMS error: squared errors are summed over the vdim components of each vector before averaging.
""" """
def __init__(self, vdim: int = 1, **kwargs): def __init__(self, vdim: int = 1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.vdim = vdim self.vdim = vdim
self._se = torch.nn.MSELoss(reduction='none') self._se = torch.nn.MSELoss(reduction='none')
def _square_error(self, y_ref, y_pred, vdim: int): def _square_error(self, y_ref, y_pred, vdim: int):
return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1) return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1)
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output) y_ref, y_pred = self._retrieve(output)
se = self._square_error(y_ref, y_pred, self.vdim) se = self._square_error(y_ref, y_pred, self.vdim)
self.value.update(se) self.value.update(se)
def get(self): def get(self):
return self.value.get() ** 0.5 return self.value.get() ** 0.5
class ComponentRMSError(ErrorMetric): class ComponentRMSError(ErrorMetric):
""" """
Ignore the vector dimension and just average over individual components.
Results in a smaller error than RMSError.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._se = torch.nn.MSELoss(reduction='none') self._se = torch.nn.MSELoss(reduction='none')
def _square_error(self, y_ref, y_pred): def _square_error(self, y_ref, y_pred):
return self._se(y_ref, y_pred) return self._se(y_ref, y_pred)
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output) y_ref, y_pred = self._retrieve(output)
y_ref = y_ref.view(-1) y_ref = y_ref.view(-1)
y_pred = y_pred.view(-1) y_pred = y_pred.view(-1)
se = self._square_error(y_ref, y_pred) se = self._square_error(y_ref, y_pred)
self.value.update(se) self.value.update(se)
def get(self): def get(self):
return self.value.get() ** 0.5 return self.value.get() ** 0.5
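# Relation between the two RMSE flavors above: RMSError averages the squared
# error summed over the vdim components of each vector, while
# ComponentRMSError averages over individual components, so for forces
# (vdim=3) ComponentRMSE = RMSE / sqrt(3).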
class MAError(ErrorMetric): class MAError(ErrorMetric):
""" """
Mean absolute error, averaged over all components.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def _abs_error(self, y_ref, y_pred):
return torch.abs(y_ref - y_pred)
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output) y_ref, y_pred = self._retrieve(output)
y_ref = y_ref.reshape((-1,)) y_ref = y_ref.reshape((-1,))
y_pred = y_pred.reshape((-1,)) y_pred = y_pred.reshape((-1,))
err = self._abs_error(y_ref, y_pred)
self.value.update(err)
class CustomError(ErrorMetric): class CustomError(ErrorMetric):
""" """
Custom error metric Custom error metric
Args: Args:
func: a function that takes y_ref and y_pred func: a function that takes y_ref and y_pred
and returns a list of errors and returns a list of errors
""" """
def __init__(self, func: Callable, **kwargs): def __init__(self, func: Callable, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.func = func self.func = func
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
y_ref, y_pred = self._retrieve(output) y_ref, y_pred = self._retrieve(output)
se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([]) se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([])
self.value.update(se) self.value.update(se)
class LossError(ErrorMetric): class LossError(ErrorMetric):
""" """
Error metric that records the value of a loss function.
""" """
def __init__( def __init__(
self, self,
name: str, name: str,
loss_def: LossDefinition, loss_def: LossDefinition,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
name, name,
ignore_unlabeled=loss_def.ignore_unlabeled,
**kwargs, **kwargs,
) )
self.loss_def = loss_def self.loss_def = loss_def
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
loss = self.loss_def.get_loss(output) # type: ignore loss = self.loss_def.get_loss(output) # type: ignore
self.value.update(loss) # type: ignore self.value.update(loss) # type: ignore
class CombinedError(ErrorMetric): class CombinedError(ErrorMetric):
""" """
Combine multiple error metrics with weights.
Corresponds to a weighted sum of errors (normally used for the loss).
""" """
def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs): def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.metrics = metrics self.metrics = metrics
assert kwargs['unit'] is None assert kwargs['unit'] is None
def update(self, output: AtomGraphData): def update(self, output: AtomGraphData):
for metric, _ in self.metrics: for metric, _ in self.metrics:
metric.update(output) metric.update(output)
def reset(self): def reset(self):
for metric, _ in self.metrics: for metric, _ in self.metrics:
metric.reset() metric.reset()
def ddp_reduce(self, device): # override def ddp_reduce(self, device): # override
for metric, _ in self.metrics: for metric, _ in self.metrics:
metric.value._ddp_reduce(device) metric.value._ddp_reduce(device)
def get(self): def get(self):
val = 0.0 val = 0.0
for metric, weight in self.metrics: for metric, weight in self.metrics:
val += metric.get() * weight val += metric.get() * weight
return val return val
class ErrorRecorder: class ErrorRecorder:
""" """
Record the errors of a model during training or evaluation.
""" """
METRIC_DICT = { METRIC_DICT = {
'RMSE': RMSError, 'RMSE': RMSError,
'ComponentRMSE': ComponentRMSError, 'ComponentRMSE': ComponentRMSError,
'MAE': MAError, 'MAE': MAError,
'Loss': LossError, 'Loss': LossError,
} }
def __init__(self, metrics: List[ErrorMetric]): def __init__(self, metrics: List[ErrorMetric]):
self.history = [] self.history = []
self.metrics = metrics self.metrics = metrics
def _update(self, output: AtomGraphData): def _update(self, output: AtomGraphData):
for metric in self.metrics: for metric in self.metrics:
metric.update(output) metric.update(output)
def update(self, output: AtomGraphData, no_grad=True): def update(self, output: AtomGraphData, no_grad=True):
if no_grad: if no_grad:
with torch.no_grad(): with torch.no_grad():
self._update(output) self._update(output)
else: else:
self._update(output) self._update(output)
def get_metric_dict(self, with_unit=True): def get_metric_dict(self, with_unit=True):
return {metric.key_str(with_unit): metric.get() for metric in self.metrics} return {metric.key_str(with_unit): metric.get() for metric in self.metrics}
def get_current(self): def get_current(self):
dct = {} dct = {}
for metric in self.metrics: for metric in self.metrics:
dct[metric.name] = { dct[metric.name] = {
'value': metric.get(), 'value': metric.get(),
'unit': metric.unit, 'unit': metric.unit,
'ref_key': metric.ref_key, 'ref_key': metric.ref_key,
'pred_key': metric.pred_key, 'pred_key': metric.pred_key,
} }
return dct return dct
def get_dct(self, prefix=''): def get_dct(self, prefix=''):
dct = {} dct = {}
if prefix and not prefix.endswith('_'):
prefix = prefix + '_' prefix = prefix + '_'
for metric in self.metrics: for metric in self.metrics:
dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}' dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}'
return dct return dct
def get_key_str(self, name: str): def get_key_str(self, name: str):
for metric in self.metrics: for metric in self.metrics:
if name == metric.name: if name == metric.name:
return metric.key_str() return metric.key_str()
return None return None
def epoch_forward(self): def epoch_forward(self):
self.history.append(self.get_current()) self.history.append(self.get_current())
pretty = self.get_metric_dict(with_unit=True) pretty = self.get_metric_dict(with_unit=True)
for metric in self.metrics: for metric in self.metrics:
metric.reset() metric.reset()
return pretty # for print return pretty # for print
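# epoch_forward() snapshots the current metrics into self.history, returns a
# printable dict, and resets every metric, so it is meant to be called once
# per epoch after the last batch has been processed.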
@staticmethod @staticmethod
def init_total_loss_metric( def init_total_loss_metric(
config, config,
criteria: Optional[Callable] = None, criteria: Optional[Callable] = None,
loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None, loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None,
): ):
if criteria is None and loss_functions is None: if criteria is None and loss_functions is None:
raise ValueError('neither criteria nor loss_functions was given')
is_stress = config[KEY.IS_TRAIN_STRESS] is_stress = config[KEY.IS_TRAIN_STRESS]
metrics = [] metrics = []
if criteria is not None: if criteria is not None:
energy_metric = CustomError(criteria, **get_err_type('Energy')) energy_metric = CustomError(criteria, **get_err_type('Energy'))
metrics.append((energy_metric, 1)) metrics.append((energy_metric, 1))
force_metric = CustomError(criteria, **get_err_type('Force')) force_metric = CustomError(criteria, **get_err_type('Force'))
metrics.append((force_metric, config[KEY.FORCE_WEIGHT])) metrics.append((force_metric, config[KEY.FORCE_WEIGHT]))
if is_stress: if is_stress:
stress_metric = CustomError(criteria, **get_err_type('Stress')) stress_metric = CustomError(criteria, **get_err_type('Stress'))
metrics.append((stress_metric, config[KEY.STRESS_WEIGHT])) metrics.append((stress_metric, config[KEY.STRESS_WEIGHT]))
else: # TODO: this is hard-coded else: # TODO: this is hard-coded
for efs in ['Energy', 'Force', 'Stress']: for efs in ['Energy', 'Force', 'Stress']:
if efs == 'Stress' and not is_stress: if efs == 'Stress' and not is_stress:
continue continue
lf, w = _get_loss_function_from_name(loss_functions, efs) lf, w = _get_loss_function_from_name(loss_functions, efs)
if lf is None: if lf is None:
raise ValueError(f'{efs} not found from loss_functions') raise ValueError(f'{efs} not found from loss_functions')
metric = LossError(loss_def=lf, **get_err_type(efs)) metric = LossError(loss_def=lf, **get_err_type(efs))
metrics.append((metric, w)) metrics.append((metric, w))
total_loss_metric = CombinedError( total_loss_metric = CombinedError(
metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None
) )
return total_loss_metric return total_loss_metric
@staticmethod @staticmethod
def from_config(config: dict, loss_functions=None): def from_config(config: dict, loss_functions=None):
loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()] loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
loss_param = config.get(KEY.LOSS_PARAM, {}) loss_param = config.get(KEY.LOSS_PARAM, {})
criteria = loss_cls(**loss_param) if loss_functions is None else None criteria = loss_cls(**loss_param) if loss_functions is None else None
err_config = config.get(KEY.ERROR_RECORD, False) err_config = config.get(KEY.ERROR_RECORD, False)
if not err_config: if not err_config:
raise ValueError( raise ValueError(
'No error_record config found. Consider util.get_error_recorder' 'No error_record config found. Consider util.get_error_recorder'
) )
err_config_n = [] err_config_n = []
if not config.get(KEY.IS_TRAIN_STRESS, True): if not config.get(KEY.IS_TRAIN_STRESS, True):
for err_type, metric_name in err_config: for err_type, metric_name in err_config:
if 'Stress' in err_type: if 'Stress' in err_type:
continue continue
err_config_n.append((err_type, metric_name)) err_config_n.append((err_type, metric_name))
err_config = err_config_n err_config = err_config_n
err_metrics = [] err_metrics = []
for err_type, metric_name in err_config: for err_type, metric_name in err_config:
metric_kwargs = get_err_type(err_type) metric_kwargs = get_err_type(err_type)
if err_type == 'TotalLoss': # special case if err_type == 'TotalLoss': # special case
err_metrics.append( err_metrics.append(
ErrorRecorder.init_total_loss_metric( ErrorRecorder.init_total_loss_metric(
config, criteria, loss_functions config, criteria, loss_functions
) )
) )
continue continue
metric_cls = ErrorRecorder.METRIC_DICT[metric_name] metric_cls = ErrorRecorder.METRIC_DICT[metric_name]
assert isinstance(metric_kwargs['name'], str) assert isinstance(metric_kwargs['name'], str)
if metric_name == 'Loss': if metric_name == 'Loss':
if loss_functions is not None: if loss_functions is not None:
metric_cls = LossError metric_cls = LossError
metric_kwargs['loss_def'], _ = _get_loss_function_from_name( metric_kwargs['loss_def'], _ = _get_loss_function_from_name(
loss_functions, metric_kwargs['name'] loss_functions, metric_kwargs['name']
) )
else: else:
metric_cls = CustomError metric_cls = CustomError
metric_kwargs['func'] = criteria metric_kwargs['func'] = criteria
metric_kwargs.pop('unit', None) metric_kwargs.pop('unit', None)
metric_kwargs['name'] += f'_{metric_name}' metric_kwargs['name'] += f'_{metric_name}'
err_metrics.append(metric_cls(**metric_kwargs)) err_metrics.append(metric_cls(**metric_kwargs))
return ErrorRecorder(err_metrics) return ErrorRecorder(err_metrics)
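# Hedged end-to-end sketch (config, model, and the data loader are
# assumptions; config must already contain an error_record entry):
#
#     recorder = ErrorRecorder.from_config(config)
#     for batch in valid_loader:
#         output = model(batch)          # AtomGraphData with predictions
#         recorder.update(output)
#     print(recorder.epoch_forward())    # e.g. {'Energy (eV/atom)': 0.0123, ...}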